diff --git a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py index 169e02d07..01b8fbad2 100644 --- a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py +++ b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py @@ -31,6 +31,7 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F return dense_mask +@tilelang.jit(out_idx=[4]) def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal): block_M = 64 block_N = 64 @@ -193,9 +194,8 @@ def test_topk_sparse_attention(): x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) - # Run Triton kernel - program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) - kernel = tilelang.compile(program, out_idx=[4]) + # Run tilelang kernel + kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) tilelang_output = kernel(q, k, v, block_mask) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py index 03854b338..ea8e78528 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -19,6 +19,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): accum_dtype = "float" kv_group_num = heads // heads_kv + @tilelang.jit(out_idx=[-1]) def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, max_selected_blocks): shape_q = [batch, heads, dim] @@ -203,7 +204,7 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): self.block_H = 64 - program = flashattn(batch, heads, heads_kv, dim, dim_v)( + self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( block_N=block_size, block_H=self.block_H, num_split=T.symbolic("num_split"), @@ -212,9 +213,6 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): max_cache_seqlen=T.symbolic("max_cache_seqlen"), max_selected_blocks=T.symbolic("max_selected_blocks")) - self.kernel = tilelang.compile( - program, out_idx=-1, target='cuda', execution_backend="cython") - props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -308,7 +306,11 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql is_causal_or_local=True, max_splits=128) - program = flashattn(batch, heads, heads_kv, dim, dim_v)( + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') + Output_partial = torch.empty((batch, heads, num_split, dim_v), + dtype=torch.float32, + device='cuda') + kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( block_N=block_size, block_H=block_H, num_split=T.symbolic("num_split"), @@ -317,14 +319,6 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql max_cache_seqlen=T.symbolic("max_cache_seqlen"), max_selected_blocks=T.symbolic("max_selected_blocks")) - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') - Output_partial = torch.empty((batch, heads, num_split, dim_v), - dtype=torch.float32, - device='cuda') - kernel = tilelang.compile(program, out_idx=-1, target='cuda', execution_backend="cython") - # print(kernel.get_kernel_source()) - - # output = kernel(query, key, value, 
block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial) output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial) return output @@ -458,7 +452,6 @@ def main(batch=8, ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) - # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size) sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) debug("output", ref, out, atol=1e-3, rtol=1e-3) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py index 0c43889f8..0c3c2de45 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -20,6 +20,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): accum_dtype = "float" kv_group_num = heads // heads_kv + @tilelang.jit(out_idx=[-1]) def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks): shape_q = [batch, heads, dim] shape_k = [batch, max_cache_seqlen, heads_kv, dim] @@ -189,7 +190,7 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): self.block_H = 64 - program = flashattn(batch, heads, heads_kv, dim, dim_v)( + self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( block_N=block_size, block_H=self.block_H, num_split=T.symbolic("num_split"), @@ -198,9 +199,6 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): max_cache_seqlen=T.symbolic("max_cache_seqlen"), num_blocks=T.symbolic("num_blocks")) - self.kernel = tilelang.compile( - program, out_idx=-1, target='cuda', execution_backend="cython") - props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -281,7 +279,7 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, is_causal_or_local=True, max_splits=128) - program = flashattn(batch, heads, heads_kv, dim, dim_v)( + kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( block_N=block_size, block_H=block_H, num_split=T.symbolic("num_split"), @@ -293,7 +291,6 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device='cuda') - kernel = tilelang.compile(program, out_idx=-1, target='cuda', execution_backend="cython") # print(kernel.get_kernel_source()) output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) diff --git a/examples/blocksparse_gemm/example_blocksparse_gemm.py b/examples/blocksparse_gemm/example_blocksparse_gemm.py index 9661dd2e9..b7785f61e 100644 --- a/examples/blocksparse_gemm/example_blocksparse_gemm.py +++ b/examples/blocksparse_gemm/example_blocksparse_gemm.py @@ -142,6 +142,7 @@ def kernel(block_M=None, return autotuner.run(warmup=3, rep=20) +@tilelang.jit(out_idx=[-1]) def blocksparse_matmul(M, N, K, @@ -211,10 +212,9 @@ def main(): print(f"Best Kernel Latency: {best_latency:.6f} ms") print(f"Reference Latency: {ref_latency:.6f} ms") else: - func = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K, - DEFAULT_NUM_STAGES, DEFAULT_THREAD_NUM, - DEFAULT_ENABLE_RASTERIZATION) - kernel = tilelang.compile(func, out_idx=-1) + kernel = 
blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K, + DEFAULT_NUM_STAGES, DEFAULT_THREAD_NUM, + DEFAULT_ENABLE_RASTERIZATION) block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})") diff --git a/examples/cast/example_group_per_split_token_cast_to_fp8.py b/examples/cast/example_group_per_split_token_cast_to_fp8.py index 917d1d63f..9b6f10441 100644 --- a/examples/cast/example_group_per_split_token_cast_to_fp8.py +++ b/examples/cast/example_group_per_split_token_cast_to_fp8.py @@ -12,6 +12,7 @@ accum_dtype = "float" +@tilelang.jit(out_idx=[2, 3]) def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): group_size = 128 fp8_min = -448.0 @@ -179,13 +180,7 @@ def main(): print("batch_sizes:", batch_sizes) print("M_max:", M_max) - program = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m) - kernel = tilelang.compile( - program, - out_idx=[2, 3], - target="cuda", - execution_backend="cython", - pass_configs={"tl.disable_tma_lower": True}) + kernel = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m) print(kernel.get_kernel_source()) # profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) diff --git a/examples/cast/example_per_token_cast_to_fp8.py b/examples/cast/example_per_token_cast_to_fp8.py index 2bbb35d42..1a1f47aeb 100644 --- a/examples/cast/example_per_token_cast_to_fp8.py +++ b/examples/cast/example_per_token_cast_to_fp8.py @@ -10,6 +10,7 @@ tilelang.disable_cache() +@tilelang.jit(out_idx=[1, 2]) def per_token_cast_to_fp8(M, N, blk_m): dtype = "float" group_size = 128 @@ -83,13 +84,7 @@ def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def main(): M, N, blk_m = 8192, 8192, 8 - program = per_token_cast_to_fp8(M, N, blk_m) - kernel = tilelang.compile( - program, - out_idx=[1, 2], - target="cuda", - execution_backend="cython", - pass_configs={"tl.disable_tma_lower": True}) + kernel = per_token_cast_to_fp8(M, N, blk_m) print(kernel.get_kernel_source()) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) diff --git a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py index b9b9f30e6..f52b9333e 100644 --- a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py +++ b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py @@ -5,13 +5,14 @@ import torch import tilelang.testing -import tilelang as TL +import tilelang import tilelang.language as T from tilelang.utils.tensor import map_torch_type tilelang.testing.set_random_seed(42) +@tilelang.jit(out_idx=[2]) def tl_gemm( M, N, @@ -147,8 +148,7 @@ def calc_diff(x, y): def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtype): - gemm = tl_gemm(M, N, K, block_N, in_dtype, out_dtype, accum_dtype) - kernel = TL.compile(gemm, out_idx=[]) + kernel = tl_gemm(M, N, K, block_N, in_dtype, out_dtype, accum_dtype) src_code = kernel.get_kernel_source() # src_code is the generated cuda source diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py index 95216939b..97fee061f 100644 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py @@ -12,6 +12,7 @@ tilelang.disable_cache() +@tilelang.jit(out_idx=[6]) def flashmla_decode(batch, heads, kv_head_num, @@ -290,9 +291,8 @@ def 
ref_program(q, q_pe, kv, k_pe, glse, Output_partial): BLOCK_H = 64 num_split = 4 - program = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, - num_split) - kernel = tilelang.compile(program, out_idx=[6]) + kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, + num_split) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) input_tensors = profiler._get_inputs() tilelang_output = kernel(*input_tensors) diff --git a/examples/deepseek_mla/benchmark_mla.py b/examples/deepseek_mla/benchmark_mla.py index f08d9dea9..2de1c0f19 100644 --- a/examples/deepseek_mla/benchmark_mla.py +++ b/examples/deepseek_mla/benchmark_mla.py @@ -436,9 +436,8 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) - program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, - num_kv_splits, block_size) - kernel = tilelang.compile(program, out_idx=[8]) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, + num_kv_splits, block_size) def flash_mla_tilelang(): out = kernel( diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index 134de8a1a..4fdc85376 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -9,6 +9,7 @@ import argparse +@tilelang.jit(out_idx=[6]) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) dtype = "float16" @@ -289,8 +290,7 @@ def main(): BLOCK_H = 64 num_split = 1 - program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split) - kernel = tilelang.compile(program, out_idx=[6]) + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) latency = profiler.do_bench(warmup=500) @@ -299,4 +299,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/deepseek_mla/example_mla_decode_paged.py b/examples/deepseek_mla/example_mla_decode_paged.py index 2d649099f..e74670d06 100644 --- a/examples/deepseek_mla/example_mla_decode_paged.py +++ b/examples/deepseek_mla/example_mla_decode_paged.py @@ -7,6 +7,7 @@ import math +@tilelang.jit(out_idx=[8]) def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split, block_size): scale = (1.0 / (dv + dpe))**0.5 * 1.44269504 # log2(e) @@ -323,9 +324,8 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) - program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, - num_kv_splits, block_size) - kernel = tilelang.compile(program, out_idx=[8]) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, + num_kv_splits, block_size) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) def flash_mla_tilelang(): diff --git a/examples/deepseek_mla/example_mla_decode_persistent.py 
b/examples/deepseek_mla/example_mla_decode_persistent.py index 90fcae04a..36fcf5ee8 100644 --- a/examples/deepseek_mla/example_mla_decode_persistent.py +++ b/examples/deepseek_mla/example_mla_decode_persistent.py @@ -10,6 +10,7 @@ import argparse +@tilelang.jit(out_idx=[6]) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) dtype = "float16" @@ -209,8 +210,7 @@ def main(): BLOCK_H = 64 num_split = 2 - program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split) - kernel = tilelang.compile(program, out_idx=[6]) + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split) print(kernel.get_kernel_source()) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) diff --git a/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py index 857fa956e..0aeffb153 100644 --- a/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py +++ b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py @@ -9,6 +9,7 @@ import argparse +@tilelang.jit(out_idx=[-1]) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H): scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) dtype = "float16" @@ -148,9 +149,7 @@ def ref_program(q, q_pe, kv, k_pe): BLOCK_N = 64 BLOCK_H = 64 - program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H) - print(program) - kernel = tilelang.compile(program, out_idx=-1) + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) latency = profiler.do_bench(warmup=500) print(f"Latency: {latency} ms") diff --git a/examples/deepseek_nsa/example_tilelang_nsa_decode.py b/examples/deepseek_nsa/example_tilelang_nsa_decode.py index c40d3d638..882b70827 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_decode.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_decode.py @@ -10,6 +10,7 @@ tilelang.testing.set_random_seed(42) +@tilelang.jit(out_idx=[-1]) def native_sparse_attention( batch, heads, @@ -132,7 +133,7 @@ def main(): B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16 groups = HQ // H SEQ_LEN_Q = 1 - program = native_sparse_attention( + kernel = native_sparse_attention( batch=B, heads=HQ, seq_len=SEQ_LEN, @@ -142,7 +143,6 @@ def main(): selected_blocks=S, ) - kernel = tilelang.compile(program, out_idx=-1) Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py index 2316d1ff5..d23a99e91 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py @@ -10,6 +10,7 @@ tilelang.testing.set_random_seed(0) +@tilelang.jit(out_idx=[-1]) def native_sparse_attention(batch, heads, seq_len, @@ -130,7 +131,7 @@ def native_sparse_attention( def main(): B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 - program = native_sparse_attention( + kernel = 
native_sparse_attention( batch=B, heads=HQ, seq_len=SEQ_LEN, @@ -141,7 +142,6 @@ def main(): selected_blocks=S, scale=scale, ) - kernel = tilelang.compile(program, out_idx=-1) print(kernel.get_kernel_source()) torch.random.manual_seed(0) Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py index b356ba826..6e357c743 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py @@ -18,6 +18,7 @@ from einops import rearrange +@tilelang.jit def native_sparse_attention_varlen(batch, heads, c_seq_len, @@ -173,7 +174,7 @@ def parallel_nsa_fwd( BS = block_size WS = window_size - program = native_sparse_attention_varlen( + kernel = native_sparse_attention_varlen( batch=batch, heads=HQ, c_seq_len=C_SEQ_LEN, @@ -184,8 +185,6 @@ def parallel_nsa_fwd( selected_blocks=S, ) - kernel = tilelang.compile(program) - o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device) kernel( q.view(C_SEQ_LEN, HQ, D), k.view(C_SEQ_LEN, H, D), v.view(C_SEQ_LEN, H, D), diff --git a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py index 8bdfe3bdb..3fc5f3119 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py @@ -10,6 +10,7 @@ tilelang.testing.set_random_seed(0) +@tilelang.jit(out_idx=[2]) def matmul( M, N, @@ -100,7 +101,7 @@ def run_gemm( num_stages=3, num_threads=128, ): - program = matmul( + kernel = matmul( M, N, K, @@ -114,7 +115,6 @@ def run_gemm( num_threads, ) - kernel = tilelang.compile(program, out_idx=[2]) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) out = profiler.run_once() @@ -437,7 +437,6 @@ def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(): def main(): test_run_dequantize_gemm() - test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4() if __name__ == "__main__": diff --git a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py index aa05dc4da..2f56537ad 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py @@ -57,6 +57,7 @@ def _convert(val, pos): return new_tensor +@tilelang.jit(out_idx=[1]) def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): num_elems_per_byte = 8 // num_bits storage_dtype = "uint8" @@ -92,7 +93,7 @@ def main( def test_fp4_fp16_convert_close(): N, K = 256, 256 block_N, block_K = 64, 64 - program = test_convert( + kernel = test_convert( N, K, block_N, @@ -100,8 +101,6 @@ def test_fp4_fp16_convert_close(): "float16", ) - kernel = tilelang.compile(program, out_idx=[1]) - B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8) tl_out = kernel(B) ref_out = torch_convert(B) @@ -131,6 +130,7 @@ def get_configs(): def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): + @tilelang.jit(out_idx=[2]) def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1): num_elems_per_byte = 8 // num_bits storage_dtype = "uint8" @@ -273,10 +273,9 @@ def main(m=256, n=256, k=256, tune=False): total_flops = 2 * m * n * k if (not tune): - program = matmul( + kernel = matmul( m, n, k, 
"float16", "float16", "float32", num_bits=4, tune=tune)( block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1) - kernel = tilelang.compile(program, out_idx=[2]) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("All checks pass.") diff --git a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py index 5e42869f2..d3e90ec93 100644 --- a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py +++ b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py @@ -7,6 +7,7 @@ _tir_packed_int_to_int_convert,) +@tilelang.jit def dequantize_gemv( M: int, N: int, @@ -173,11 +174,9 @@ def main() -> None: group_size = -1 with_scaling = False - program = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype, - source_format, n_partition, reduce_thread, fast_decoding, trans_A, - trans_B, group_size, with_scaling) - - kernel = tilelang.compile(program) + kernel = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype, + source_format, n_partition, reduce_thread, fast_decoding, trans_A, + trans_B, group_size, with_scaling) storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) num_elems_per_byte = storage_nbit // num_bits diff --git a/examples/dynamic_shape/example_dynamic.py b/examples/dynamic_shape/example_dynamic.py index 8973339b3..70ff984a3 100644 --- a/examples/dynamic_shape/example_dynamic.py +++ b/examples/dynamic_shape/example_dynamic.py @@ -10,6 +10,7 @@ tilelang.disable_cache() +@tilelang.jit(pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8}) def matmul_dynamic_mnk( block_M, block_N, @@ -63,14 +64,8 @@ def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtyp print( f"M: {M}, N: {N}, K: {K}, block_M: {block_M}, block_N: {block_N}, block_K: {block_K}, trans_A: {trans_A}, trans_B: {trans_B}, in_dtype: {in_dtype}, out_dtype: {out_dtype}, accum_dtype: {accum_dtype}, num_stages: {num_stages}, threads: {threads}" ) - program = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, - accum_dtype, num_stages, threads) - - kernel = tilelang.compile( - program, pass_configs={ - "tl.disable_dynamic_tail_split": True, - "tl.dynamic_alignment": 8 - }) + kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, + accum_dtype, num_stages, threads) import torch if trans_A: diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py index 4b091b7f6..82d15cf5a 100644 --- a/examples/elementwise/example_elementwise_add.py +++ b/examples/elementwise/example_elementwise_add.py @@ -12,6 +12,7 @@ def ref_program(x, y): return x + y +@tilelang.jit(out_idx=[-1]) def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): @T.prim_func @@ -70,8 +71,7 @@ def main(): else: # Default config config = {"block_M": 128, "block_N": 256, "threads": 128} - kernel = tilelang.compile( - elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32"), out_idx=-1) + kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") out = kernel(a, b) torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) diff --git a/examples/flash_attention/example_gqa_fwd_bshd.py b/examples/flash_attention/example_gqa_fwd_bshd.py index edd57d3bc..4b7d70ff3 100644 --- 
a/examples/flash_attention/example_gqa_fwd_bshd.py +++ b/examples/flash_attention/example_gqa_fwd_bshd.py @@ -68,6 +68,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): dtype = "float16" accum_dtype = "float" + @tilelang.jit(out_idx=[3]) def kernel_func(block_M, block_N, num_stages, threads): @T.macro @@ -243,11 +244,10 @@ def main(batch: int = 1, total_flops *= 0.5 if (not tune): - program = flashattn( + kernel = flashattn( batch, heads, seq_len, dim, is_causal, tune=tune, groups=groups)( block_M=64, block_N=64, num_stages=2, threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) - kernel = tilelang.compile(program, out_idx=[3]) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) print("All checks pass.") diff --git a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py index 1e2e7e0d1..66309532b 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py @@ -35,6 +35,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): dtype = "float16" accum_dtype = "float" + @tilelang.jit(out_idx=[3]) def kernel_func(block_M, block_N, num_stages, threads): @T.macro @@ -217,11 +218,10 @@ def main( total_flops *= 0.5 if (not tune): - program = flashattn( + kernel = flashattn( batch, heads, seq_len, dim, is_causal, tune=tune, groups=groups)( block_M=128, block_N=128, num_stages=2, threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) - kernel = tilelang.compile(program, out_idx=[3]) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) print("All checks pass.") diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py index 228d36614..bbd7abc20 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd.py +++ b/examples/flash_attention/example_mha_fwd_bhsd.py @@ -35,6 +35,7 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False): dtype = "float16" accum_dtype = "float" + @tilelang.jit(out_idx=[3]) def kernel_func(block_M, block_N, num_stages, threads): @T.macro @@ -200,11 +201,10 @@ def main( total_flops *= 0.5 if (not tune): - program = flashattn( + kernel = flashattn( batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)( block_M=64, block_N=64, num_stages=1, threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal) - kernel = tilelang.compile(program, out_idx=[3]) profiler = kernel.get_profiler() profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) diff --git a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py index f0591d457..e11e855bd 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -35,6 +35,7 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False): dtype = "float16" accum_dtype = "float" + @tilelang.jit(out_idx=[3]) def kernel_func(block_M, block_N, num_stages, threads): @T.macro @@ -205,11 +206,10 @@ def main( total_flops *= 0.5 if (not tune): - program = flashattn( + kernel = flashattn( 
batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)( block_M=128, block_N=128, num_stages=2, threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal) - kernel = tilelang.compile(program, out_idx=[3]) profiler = kernel.get_profiler() profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) diff --git a/examples/flash_attention/example_mha_fwd_bshd.py b/examples/flash_attention/example_mha_fwd_bshd.py index cc7e82376..60765c0c2 100644 --- a/examples/flash_attention/example_mha_fwd_bshd.py +++ b/examples/flash_attention/example_mha_fwd_bshd.py @@ -33,6 +33,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False): dtype = "float16" accum_dtype = "float" + @tilelang.jit(out_idx=[3]) def kernel_func(block_M, block_N, num_stages, threads): @T.macro @@ -194,11 +195,10 @@ def main( total_flops *= 0.5 if (not tune): - program = flashattn( + kernel = flashattn( batch, heads, seq_len, dim, is_causal, tune=tune)( block_M=128, block_N=128, num_stages=1, threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal) - kernel = tilelang.compile(program, out_idx=[3]) profiler = kernel.get_profiler() profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) print("All checks pass.") diff --git a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py index 9fb67b34c..5b0d35f89 100644 --- a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py @@ -33,6 +33,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False): dtype = "float16" accum_dtype = "float" + @tilelang.jit(out_idx=[3]) def kernel_func(block_M, block_N, num_stages, threads): @T.macro @@ -199,11 +200,10 @@ def main( total_flops *= 0.5 if (not tune): - program = flashattn( + kernel = flashattn( batch, heads, seq_len, dim, is_causal, tune=tune)( block_M=128, block_N=128, num_stages=2, threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal) - kernel = tilelang.compile(program, out_idx=[3]) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) print("All checks pass.") diff --git a/examples/flash_attention/example_mha_fwd_varlen.py b/examples/flash_attention/example_mha_fwd_varlen.py index 204338ae2..80714b404 100644 --- a/examples/flash_attention/example_mha_fwd_varlen.py +++ b/examples/flash_attention/example_mha_fwd_varlen.py @@ -234,6 +234,7 @@ def flashattn(batch_size, UQ, UKV, heads, dim, is_causal): dtype = "float16" accum_dtype = "float" + @tilelang.jit(out_idx=[6]) def kernel_func(block_M, block_N, num_stages, threads): @T.prim_func @@ -402,8 +403,7 @@ def main(batch: int = 2, heads: int = 16, seq_len: int = 256, dim: int = 32): UK = k_unpad.shape[0] # unpadded key length UKV = k_unpad.shape[0] # unpadded query key length - program = flashattn(batch, UQ, UKV, heads, dim, causal) - kernel = tilelang.compile(program, [6]) + kernel = flashattn(batch, UQ, UKV, heads, dim, causal) print(kernel.get_kernel_source()) out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) diff --git a/examples/flash_attention/example_mha_inference.py b/examples/flash_attention/example_mha_inference.py index 0ee812dba..3c0d64585 100644 --- a/examples/flash_attention/example_mha_inference.py +++ b/examples/flash_attention/example_mha_inference.py @@ -8,6 +8,7 @@ num_split 
= 4 +@tilelang.jit(out_idx=[5]) def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N): scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) shape_q = [batch, seqlen_q, heads, dim] @@ -303,9 +304,8 @@ def main(): total_flops *= 0.5 BLOCK_M = 128 BLOCK_N = 64 # if D_HEAD <= 128 else 32 - program = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N) + kernel = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N) ref_program_processed = partial(ref_program, causal=causal) - kernel = tilelang.compile(program, out_idx=[5]) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) print("All checks passed!") diff --git a/examples/flash_decoding/example_gqa_decode.py b/examples/flash_decoding/example_gqa_decode.py index 9b93f7d3d..cc5ed7358 100644 --- a/examples/flash_decoding/example_gqa_decode.py +++ b/examples/flash_decoding/example_gqa_decode.py @@ -6,6 +6,8 @@ from einops import rearrange, einsum import argparse import itertools +from functools import lru_cache +from typing import Tuple, Dict torch.random.manual_seed(0) @@ -28,6 +30,30 @@ def get_configs(): return configs +@lru_cache(maxsize=1) +def get_heuristic_config() -> Tuple[Dict, int]: + # Get CUDA device properties + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + device = torch.cuda.current_device() + sm_major, sm_minor = torch.cuda.get_device_capability(device) + sm_version = sm_major * 10 + sm_minor + print(f"CUDA device capability: {sm_version}") + if sm_version == 89: + cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=0, threads=128) + else: + cfg = dict(block_N=128, block_H=64, num_split=16, num_stages=2, threads=128) + return cfg, sm_version + + +def get_pass_configs(): + _, sm_version = get_heuristic_config() + if sm_version == 80: + return {tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True} + else: + return {} + + def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False): scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) shape_q = [batch, heads, dim] @@ -38,6 +64,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False): accum_dtype = "float" kv_group_num = heads // groups + @tilelang.jit(out_idx=[6], pass_configs=get_pass_configs()) def kernel_func(block_N, block_H, num_split, num_stages, threads): part_shape = [batch, heads, num_split, dim] valid_block_H = min(block_H, kv_group_num) @@ -457,39 +484,8 @@ def main(batch: int = 1, total_flops = qk_flops + pv_flops if (not tune): - - def get_heuristic_config() -> dict: - # Get CUDA device properties - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is not available") - device = torch.cuda.current_device() - sm_major, sm_minor = torch.cuda.get_device_capability(device) - sm_version = sm_major * 10 + sm_minor - print(f"CUDA device capability: {sm_version}") - if sm_version == 89: - return { - "block_N": 128, - "block_H": 64, - "num_split": 16, - "num_stages": 0, - "threads": 128 - }, sm_version - else: - return { - "block_N": 128, - "block_H": 64, - "num_split": 16, - "num_stages": 2, - "threads": 128 - }, sm_version - config, sm_version = get_heuristic_config() - program = flashattn(batch, heads, groups, kv_seqlen, dim, tune=tune)(**config) - if sm_version == 90: - kernel = tilelang.compile( - program, out_idx=[6], pass_configs={"tl.disable_tma_lower": True}) - else: - kernel = tilelang.compile(program, out_idx=[6]) + kernel = flashattn(batch, heads, groups, kv_seqlen, dim, 
tune=tune)(**config) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16) diff --git a/examples/flash_decoding/example_mha_inference.py b/examples/flash_decoding/example_mha_inference.py index eb8093df4..7dd6f924e 100644 --- a/examples/flash_decoding/example_mha_inference.py +++ b/examples/flash_decoding/example_mha_inference.py @@ -8,6 +8,7 @@ num_split = 4 +@tilelang.jit(out_idx=[5], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}) def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N): scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) shape_q = [batch, seqlen_q, heads, dim] @@ -302,10 +303,8 @@ def main(): total_flops *= 0.5 BLOCK_M = 128 BLOCK_N = 64 # if D_HEAD <= 128 else 32 - program = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N) + kernel = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N) ref_fn = partial(ref_program, causal=causal) - kernel = tilelang.compile( - program, out_idx=[5], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}) print(kernel.get_kernel_source()) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_fn, rtol=0.01, atol=0.01) @@ -320,4 +319,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/gemm/example_gemm.py b/examples/gemm/example_gemm.py index 2e8f3c54a..f40da6c73 100644 --- a/examples/gemm/example_gemm.py +++ b/examples/gemm/example_gemm.py @@ -5,6 +5,7 @@ import tilelang.language as T +@tilelang.jit(out_idx=[-1]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func @@ -30,11 +31,7 @@ def gemm( def main(): - func = matmul(1024, 1024, 1024, 128, 128, 32) - - print(func) - - kernel = tilelang.compile(func, out_idx=-1) + kernel = matmul(1024, 1024, 1024, 128, 128, 32) import torch diff --git a/examples/gemm/example_gemm_autotune.py b/examples/gemm/example_gemm_autotune.py index 0b3094c40..c19d36c67 100644 --- a/examples/gemm/example_gemm_autotune.py +++ b/examples/gemm/example_gemm_autotune.py @@ -166,6 +166,7 @@ def get_heuristic_config() -> dict: } +@tl.jit(out_idx=[-1]) def matmul(M, N, K, @@ -219,7 +220,7 @@ def main(m: int = 4096, kernel = result.kernel else: config = get_heuristic_config() - kernel = tl.compile(matmul(M, N, K, **config), out_idx=-1) + kernel = matmul(M, N, K, **config) # benchmark profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) diff --git a/examples/gemm/example_gemm_intrinsics.py b/examples/gemm/example_gemm_intrinsics.py index a89e7a497..4150d5a9d 100644 --- a/examples/gemm/example_gemm_intrinsics.py +++ b/examples/gemm/example_gemm_intrinsics.py @@ -26,6 +26,7 @@ def transform_func(i, j): return T.Layout(shape, transform_func) +@tilelang.jit(out_idx=[2]) @simplify_prim_func def tl_matmul( M, @@ -167,8 +168,7 @@ def ref_program(A, B): def main(): M, N, K = 16384, 16384, 16384 in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32" - matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - kernel = tilelang.compile(matmul, out_idx=[2]) + kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) src_code = kernel.get_kernel_source() # src_code is the generated cuda source assert src_code is not None diff --git a/examples/gemm/example_gemm_persistent.py b/examples/gemm/example_gemm_persistent.py index 11005e7d6..6abfca98d 
100644 --- a/examples/gemm/example_gemm_persistent.py +++ b/examples/gemm/example_gemm_persistent.py @@ -7,6 +7,7 @@ import argparse +@tilelang.jit(out_idx=[-1]) def matmul_non_persistent(M, N, K, @@ -44,6 +45,7 @@ def main( return main +@tilelang.jit(out_idx=[-1]) def matmul_persistent(M, N, K, @@ -134,8 +136,7 @@ def main(): threads = 256 num_stages = 3 - persistent_program = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) - persistent_kernel = tilelang.compile(persistent_program, out_idx=-1) + persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) persistent_profiler = persistent_kernel.get_profiler( tensor_supply_type=tilelang.TensorSupplyType.Randn) persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) @@ -144,9 +145,8 @@ def main(): print(f"Persistent GEMM Latency: {persistent_latency} ms") print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops") - non_persistent_program = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, - num_stages) - non_persistent_kernel = tilelang.compile(non_persistent_program, out_idx=-1) + non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, + num_stages) non_persistent_profiler = non_persistent_kernel.get_profiler( tensor_supply_type=tilelang.TensorSupplyType.Randn) non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) diff --git a/examples/gemm/example_gemm_schedule.py b/examples/gemm/example_gemm_schedule.py index cb20b41ea..ff9427d9a 100644 --- a/examples/gemm/example_gemm_schedule.py +++ b/examples/gemm/example_gemm_schedule.py @@ -5,6 +5,7 @@ import tilelang.language as T +@tilelang.jit(out_idx=[-1]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func @@ -43,11 +44,7 @@ def gemm_schedule( def main(): - func = matmul(1024, 1024, 1024, 128, 128, 32) - - print(func) - - kernel = tilelang.compile(func, out_idx=-1) + kernel = matmul(1024, 1024, 1024, 128, 128, 32) import torch diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8.py b/examples/gemm_fp8/example_tilelang_gemm_fp8.py index a9a169b61..59e904f67 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8.py @@ -14,6 +14,7 @@ def calc_diff(x, y): return 1 - sim +@tilelang.jit(out_idx=[-1]) def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"): @T.prim_func @@ -41,9 +42,8 @@ def gemm_fp8( def test_gemm_fp8(M, N, K, dtype): torch_dtype = map_torch_type(dtype) - func = matmul(M, N, K, 128, 128, 64, dtype) + kernel = matmul(M, N, K, 128, 128, 64, dtype) - kernel = tilelang.compile(func, out_idx=-1) a = torch.randn(M, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype) b = torch.randn(N, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype) @@ -65,4 +65,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py index 20b2b1453..4bd8a2f57 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py @@ -7,6 +7,7 @@ from tilelang.utils.tensor import map_torch_type +@tilelang.jit(out_idx=[-1]) def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"): # for fp8 gemm, do one promote after 4 wgmma inst, i.e. block_K = 128. 
# if block_K < 128, promote after 128/block_K iters. @@ -59,9 +60,7 @@ def calc_diff(x, y): def test_gemm_fp8(M, N, K, dtype): torch_dtype = map_torch_type(dtype) - func = matmul(M, N, K, 128, 128, 64, dtype) - - kernel = tilelang.compile(func, out_idx=-1) + kernel = matmul(M, N, K, 128, 128, 64, dtype) a = torch.rand(M, K, dtype=torch.float16, device='cuda') a = (100 * (2 * a - 1)).to(dtype=torch_dtype) @@ -83,4 +82,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index c3c0b054c..794906d7c 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -31,6 +31,7 @@ def transform_func(i, j): return T.Layout(shape, transform_func) +@tilelang.jit(out_idx=[2]) @simplify_prim_func def tl_matmul( M, @@ -179,8 +180,7 @@ def gemm_fp8_intrinsic( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): - matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - kernel = tilelang.compile(matmul, out_idx=[2]) + kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) src_code = kernel.get_kernel_source() print(src_code) # src_code is the generated cuda source @@ -224,4 +224,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk.py b/examples/gemm_splitk/example_tilelang_gemm_splitk.py index f61de6191..80237c1d4 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk.py @@ -5,6 +5,7 @@ import tilelang.language as T +@tilelang.jit def matmul(M, N, K, @@ -62,9 +63,7 @@ def main(): block_K = 32 split_k = 4 - program = matmul(M, N, K, block_M, block_N, block_K, split_k) - - kernel = tilelang.compile(program) + kernel = matmul(M, N, K, block_M, block_N, block_K, split_k) import torch diff --git a/examples/gemm_streamk/example_tilelang_gemm_streamk.py b/examples/gemm_streamk/example_tilelang_gemm_streamk.py index 7bf53c7cb..c4be9230c 100644 --- a/examples/gemm_streamk/example_tilelang_gemm_streamk.py +++ b/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -57,6 +57,7 @@ def cdiv(a, b): sm_patition_factor = max(blocking_tiles // total_sm, 1) +@tilelang.jit def tl_matmul_streamk( M, N, @@ -173,7 +174,7 @@ def main( def main(): - _tl_matmul_streamk = tl_matmul_streamk( + kernel = tl_matmul_streamk( m, n, k, @@ -190,7 +191,6 @@ def main(): 64, ) - kernel = tilelang.compile(_tl_matmul_streamk) print(kernel.get_kernel_source()) b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16) diff --git a/examples/gemv/example_gemv.py b/examples/gemv/example_gemv.py index 2902bbb77..4b06e2055 100644 --- a/examples/gemv/example_gemv.py +++ b/examples/gemv/example_gemv.py @@ -13,6 +13,7 @@ def ref_program(A, B): return A @ B.T +@tl.jit(out_idx=[-1]) def naive_gemv( N: int, K: int, @@ -46,6 +47,7 @@ def main( return main +@tl.jit(out_idx=[-1]) def naive_splitk_gemv( N: int, K: int, @@ -81,6 +83,7 @@ def main( return main +@tl.jit(out_idx=[-1]) def splitk_gemv( N: int, K: int, @@ -120,6 +123,7 @@ def main( return main +@tl.jit(out_idx=[-1]) def splitk_gemv_vectorized( N: int, K: int, @@ -160,6 +164,7 @@ def main( return main +@tl.jit(out_idx=[-1]) def splitk_gemv_vectorized_tvm( N: int, K: int, @@ -292,7 +297,6 @@ def main( def check_correctness_and_bench(kernel, N, K, 
bench_ref=True): - kernel = tl.compile(kernel, out_idx=-1) profiler = kernel.get_profiler() profiler.assert_allclose(lambda x, y: x @ y.T, atol=1e-2, rtol=1e-2) if bench_ref: @@ -318,7 +322,6 @@ def main(): best_result = get_best_config(N, K) best_config = best_result.config kernel = splitk_gemv_vectorized_tvm(N, K, **best_config) - kernel = tl.compile(kernel, out_idx=-1) profiler = kernel.get_profiler() latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=500) print(f"Torch Latency: {latency} ms") diff --git a/examples/grouped_gemm/example_grouped_gemm_bwd.py b/examples/grouped_gemm/example_grouped_gemm_bwd.py index 315a71f1c..7154ab20a 100644 --- a/examples/grouped_gemm/example_grouped_gemm_bwd.py +++ b/examples/grouped_gemm/example_grouped_gemm_bwd.py @@ -10,6 +10,11 @@ tilelang.disable_cache() +@tilelang.jit( + out_idx=[2], pass_configs={ + "tl.disable_tma_lower": True, + "tl.disable_warp_specialized": True + }) def grouped_gemm_fwd(batch_sum, batch_count, K, @@ -106,16 +111,9 @@ def forward(ctx, a, b, batch_sizes): batch_padded_offsets = torch.tensor( batch_padded_offsets_list, device=a.device, dtype=torch.int32) - program = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, - num_stages, threads) + kernel = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, + num_stages, threads) - kernel = tilelang.compile( - program, - out_idx=[2], - pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) o = kernel(a, b, batch_sizes, batch_offsets, batch_padded_offsets) ctx.save_for_backward(a, b, batch_sizes, batch_offsets) ctx.batch_sum = batch_sum @@ -142,15 +140,8 @@ def maybe_contiguous(x): return x A, B, batch_sizes = [maybe_contiguous(x) for x in (A, B, batch_sizes)] - program = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, - num_stages, threads) - kernel = tilelang.compile( - program, - out_idx=[2], - pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) + kernel = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K, + num_stages, threads) dB = kernel(A, grad_output, batch_sizes, batch_offsets) return None, dB, None @@ -201,6 +192,11 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets +@tilelang.jit( + out_idx=[2], pass_configs={ + "tl.disable_tma_lower": True, + "tl.disable_warp_specialized": True + }) def grouped_gemm_bwd(batch_sum, batch_count, M, diff --git a/examples/grouped_gemm/example_grouped_gemm_fwd.py b/examples/grouped_gemm/example_grouped_gemm_fwd.py index e598506c3..7e97cb16c 100644 --- a/examples/grouped_gemm/example_grouped_gemm_fwd.py +++ b/examples/grouped_gemm/example_grouped_gemm_fwd.py @@ -10,6 +10,11 @@ tilelang.disable_cache() +@tilelang.jit( + out_idx=[2], pass_configs={ + "tl.disable_tma_lower": True, + "tl.disable_warp_specialized": True + }) def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False): """ Perform grouped matrix multiplication using PyTorch. 
@@ -42,6 +47,11 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False): return output +@tilelang.jit( + out_idx=[2], pass_configs={ + "tl.disable_tma_lower": True, + "tl.disable_warp_specialized": True + }) def grouped_gemm(batch_sizes_list, K, N, @@ -143,14 +153,7 @@ def run_tilelang_grouped_gemm(batch_sizes_list, profile=False): padding_M = block_M batch_sum = sum(batch_sizes_list) - program = grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, num_stages, threads) - kernel = tilelang.compile( - program, - out_idx=[2], - pass_configs={ - "tl.disable_tma_lower": True, - "tl.disable_warp_specialized": True - }) + kernel = grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, num_stages, threads) # print(kernel.get_kernel_source()) device = torch.device("cuda") diff --git a/examples/hadamard_transform/example_hadamard.py b/examples/hadamard_transform/example_hadamard.py index 2be9a2c19..130434917 100644 --- a/examples/hadamard_transform/example_hadamard.py +++ b/examples/hadamard_transform/example_hadamard.py @@ -16,6 +16,7 @@ def is_pow_of_2(n): return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0 +@tilelang.jit(out_idx=[1]) def hadamard(b, n, dtype): assert is_pow_of_2(n), "n must be a power of 2" assert 2 <= n <= 32768, "n must be in [2, 32768]" @@ -145,7 +146,7 @@ def main(): B, D = args.batch, args.dim x = torch.randn((B, D), device='cuda') - kernel = tilelang.compile(hadamard(B, D, 'float32'), out_idx=1) + kernel = hadamard(B, D, 'float32') y = kernel(x) y_ref = ref_program(x) torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2) diff --git a/examples/linear_attention/example_linear_attn_bwd.py b/examples/linear_attention/example_linear_attn_bwd.py index d03398627..b0db08ed8 100644 --- a/examples/linear_attention/example_linear_attn_bwd.py +++ b/examples/linear_attention/example_linear_attn_bwd.py @@ -10,6 +10,7 @@ from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA +@tl.jit(out_idx=[4, 5, 6]) def chunk_linear_attn_bwd_kernel( B, S, @@ -158,8 +159,7 @@ def main(): v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) do = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - fn = chunk_linear_attn_bwd_kernel(B, S, H, D, D) - kernel = tl.compile(fn, out_idx=[4, 5, 6], target='cuda') + kernel = chunk_linear_attn_bwd_kernel(B, S, H, D, D) dq, dk, dv = postprocess(*kernel(q, k, v, do)) o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) o_ref.backward(do, retain_graph=True) diff --git a/examples/linear_attention/example_linear_attn_fwd.py b/examples/linear_attention/example_linear_attn_fwd.py index e9e96de13..afba81a02 100644 --- a/examples/linear_attention/example_linear_attn_fwd.py +++ b/examples/linear_attention/example_linear_attn_fwd.py @@ -10,6 +10,7 @@ from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA +@tl.jit(out_idx=[3, 4]) def chunk_linear_attn_fwd_kernel( B, S, @@ -100,8 +101,7 @@ def main(): k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - fn = chunk_linear_attn_fwd_kernel(B, S, H, D, D) - kernel = tl.compile(fn, out_idx=[3, 4], target='cuda') + kernel = chunk_linear_attn_fwd_kernel(B, S, H, D, D) o, h = postprocess(*kernel(q, k, v)) o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) diff --git a/examples/linear_attention/example_mamba_chunk_scan.py 
b/examples/linear_attention/example_mamba_chunk_scan.py
index 2de348bef..1bc53d767 100644
--- a/examples/linear_attention/example_mamba_chunk_scan.py
+++ b/examples/linear_attention/example_mamba_chunk_scan.py
@@ -82,6 +82,7 @@ def get_configs():
     return configs
 
 
+@tilelang.jit(out_idx=[7])
 def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, tune=False):
     dtype = "float16"
     accum_dtype = "float"
@@ -232,10 +233,9 @@ def kernel(block_M, block_N, block_K, block_Dstate, num_stages, threads):
     total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate
 
     if (not args.tune):
-        program = chunk_scan_fwd(
+        kernel = chunk_scan_fwd(
             batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)(
                 block_M=64, block_N=64, block_K=64, block_Dstate=128, num_stages=2, threads=128)
-        kernel = tilelang.compile(program, out_idx=[7])
         profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
         profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
         print("All checks pass.")
diff --git a/examples/linear_attention/example_mamba_chunk_state.py b/examples/linear_attention/example_mamba_chunk_state.py
index eaa665031..dd299c3a7 100644
--- a/examples/linear_attention/example_mamba_chunk_state.py
+++ b/examples/linear_attention/example_mamba_chunk_state.py
@@ -65,6 +65,7 @@ def get_configs():
     return configs
 
 
+@tilelang.jit(out_idx=[4])
 def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, tune=False):
     dtype = "float16"
     accum_dtype = "float"
@@ -169,10 +170,9 @@ def kernel(block_M, block_N, block_K, num_stages, threads):
     total_flops = 2 * batch * seq_len * heads * dim * dstate
 
     if (not args.tune):
-        program = chunk_state_fwd(
+        kernel = chunk_state_fwd(
            batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)(
                block_M=64, block_N=128, block_K=64, num_stages=4, threads=128)
-        kernel = tilelang.compile(program, out_idx=[4])
         profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
         profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
         print("All checks pass.")
diff --git a/examples/linear_attention/example_retnet.py b/examples/linear_attention/example_retnet.py
index ec84dab6c..0b05ed3da 100644
--- a/examples/linear_attention/example_retnet.py
+++ b/examples/linear_attention/example_retnet.py
@@ -4,6 +4,7 @@
 import tilelang.language as T
 
 
+@tilelang.jit(out_idx=[4])
 def retnet(batch, heads, seq_len, dim_qk, dim_v, block_M, block_N):
     qk_shape = [batch, seq_len, heads, dim_qk]
     v_shape = [batch, seq_len, heads, dim_v]
@@ -179,8 +180,7 @@ def main(
     total_flops = 2.0 * BATCH * H * N_CTX * N_CTX * (dim_qk + dim_v)
     BLOCK_M = 64
     BLOCK_N = 64
-    program = retnet(BATCH, H, N_CTX, dim_qk, dim_v, BLOCK_M, BLOCK_N)
-    kernel = tilelang.compile(program, out_idx=[4])
+    kernel = retnet(BATCH, H, N_CTX, dim_qk, dim_v, BLOCK_M, BLOCK_N)
     profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
 
     ins = profiler._get_inputs()
diff --git a/examples/norm/rms_norm.py b/examples/norm/rms_norm.py
index 51af7e803..fd81acda5 100644
--- a/examples/norm/rms_norm.py
+++ b/examples/norm/rms_norm.py
@@ -36,6 +36,7 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
     return main
 
 
+@tilelang.jit(out_idx=[-1], pass_configs={"tl.disable_tma_lower": True})
 def rms_norm(M, N, blk_m):
     dtype = "float"
 
@@ -67,13 +68,7 @@ def ref_program(x):
 
 if __name__ == "__main__":
     M, N, blk_m, blk_k = 8192, 8192, 1, 512
-    program = rms_norm(M, N, blk_m)
-    kernel = tilelang.compile(
-        program,
-        out_idx=-1,
-        target="cuda",
-        execution_backend="cython",
-        pass_configs={"tl.disable_tma_lower": True})
+    kernel = rms_norm(M, N, blk_m)
     profiler = kernel.get_profiler()
     profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
     print("All checks pass.")
@@ -81,4 +76,4 @@ def ref_program(x):
     latency = profiler.do_bench(ref_program, warmup=500)
     print("Ref: {:.2f} ms".format(latency))
     latency = profiler.do_bench(warmup=500)
-    print("Tile-lang: {:.2f} ms".format(latency))
\ No newline at end of file
+    print("Tile-lang: {:.2f} ms".format(latency))
diff --git a/examples/online_softmax/online_softmax.py b/examples/online_softmax/online_softmax.py
index c88e5cf2d..bf2c5d8ba 100644
--- a/examples/online_softmax/online_softmax.py
+++ b/examples/online_softmax/online_softmax.py
@@ -7,6 +7,7 @@
 from typing import Callable
 
 
+@tl.jit(out_idx=[1])
 def softmax_kernel(
     M,
     N,
@@ -61,10 +62,9 @@ def main(
 
 M = 8192
 N = 8192
-fn = softmax_kernel(M, N)
+kernel = softmax_kernel(M, N)
 dtype = torch.float16
 X = torch.randn(M, N, dtype=dtype, device="cuda")
-kernel = tl.compile(fn, out_idx=[1], target="cuda")
 Y = kernel(X).to(dtype)
 
 Y_ref = X.softmax(dim=1)
diff --git a/examples/seer_attention/block_sparse_attn_tilelang.py b/examples/seer_attention/block_sparse_attn_tilelang.py
index e3fa7caa3..0bfe5fd74 100644
--- a/examples/seer_attention/block_sparse_attn_tilelang.py
+++ b/examples/seer_attention/block_sparse_attn_tilelang.py
@@ -31,6 +31,7 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
     return dense_mask
 
 
+@tilelang.jit(out_idx=[4])
 def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_causal):
     block_M = 64
     block_N = 64
@@ -176,10 +177,9 @@ def test_topk_sparse_attention():
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
 
-    # Run Triton kernel
-    program = blocksparse_flashattn(
+    # Run tilelang kernel
+    kernel = blocksparse_flashattn(
         BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
-    kernel = tilelang.compile(program, out_idx=[4])
     print(kernel.get_kernel_source())
     tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
 
@@ -226,10 +226,8 @@ def test_topk_sparse_attention_qlen_lt_klen():
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
 
-    program = blocksparse_flashattn(
+    kernel = blocksparse_flashattn(
         BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True)
-    print(program)
-    kernel = tilelang.compile(program, out_idx=[4])
     print(kernel.get_kernel_source())
     tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
 
@@ -267,4 +265,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/examples/warp_specialize/example_warp_specialize_flashmla.py b/examples/warp_specialize/example_warp_specialize_flashmla.py
index 8c8ae9422..11427a3a3 100644
--- a/examples/warp_specialize/example_warp_specialize_flashmla.py
+++ b/examples/warp_specialize/example_warp_specialize_flashmla.py
@@ -9,6 +9,7 @@
 from einops import rearrange, einsum
 
 
+@tilelang.jit(out_idx=[6])
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
     scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
     dtype = "float16"
@@ -174,8 +175,7 @@ def main():
     BLOCK_H = 64
     num_split = 1
 
-    program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
-    kernel = tilelang.compile(program, out_idx=[6])
+    kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
     print(kernel.get_kernel_source())
 
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
diff --git a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
index 3e345d614..11519b6b5 100644
--- a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
+++ b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
@@ -6,6 +6,7 @@
 
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
+@tilelang.jit(out_idx=[2])
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
     num_stages = 2
 
@@ -59,19 +60,10 @@ def main():
     block_M = 128
     block_N = 128
     block_K = 64
-    # 1. Define the kernel (matmul) and compile/lower it into an executable module
-    func = matmul(M, N, K, block_M, block_N, block_K)
+    jit_kernel = matmul(M, N, K, block_M, block_N, block_K)
 
-    # 2. Compile the kernel into a torch function
-    # out_idx specifies the index of the output buffer in the argument list
-    # if out_idx is specified, the tensor will be created during runtime
-    # target currently can be "cuda" or "hip" or "cpu".
-    jit_kernel = tilelang.compile(func, out_idx=[2])
-
-    # 3. Test the kernel in Python with PyTorch data
     import torch
 
-    # Create random input tensors on the GPU
     a = torch.randn(M, K, device="cuda", dtype=torch.float16)
     b = torch.randn(K, N, device="cuda", dtype=torch.float16)
 
diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py
index 81dd3ec17..1c0513f44 100644
--- a/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py
+++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py
@@ -6,6 +6,7 @@
 
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
+@tilelang.jit(out_idx=[2])
 def matmul_warp_specialize_copy_0_gemm_1(M,
                                          N,
                                          K,
@@ -58,19 +59,8 @@ def main():
     block_N = 128
     block_K = 64
 
-    # 1. Define the kernel (matmul) and compile/lower it into an executable module
-    func = matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K)
+    jit_kernel = matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K)
 
-    # 2. Compile the kernel into a torch function
-    # out_idx specifies the index of the output buffer in the argument list
-    # if out_idx is specified, the tensor will be created during runtime
-    # target currently can be "cuda" or "hip" or "cpu".
-    jit_kernel = tilelang.compile(
-        func,
-        out_idx=[2],
-    )
-
-    # 3. Test the kernel in Python with PyTorch data
     import torch
 
     # Create random input tensors on the GPU
diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py
index ea048623a..dbd45e90e 100644
--- a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py
+++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py
@@ -6,6 +6,7 @@
 
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
+@tilelang.jit(out_idx=[2])
 def matmul_warp_specialize_copy_1_gemm_0(M,
                                          N,
                                          K,
@@ -58,22 +59,10 @@ def main():
     block_N = 128
     block_K = 64
 
-    # 1. Define the kernel (matmul) and compile/lower it into an executable module
-    func = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K)
+    jit_kernel = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K)
 
-    # 2. Compile the kernel into a torch function
-    # out_idx specifies the index of the output buffer in the argument list
-    # if out_idx is specified, the tensor will be created during runtime
-    # target currently can be "cuda" or "hip" or "cpu".
-    jit_kernel = tilelang.compile(
-        func,
-        out_idx=[2],
-    )
-
-    # 3. Test the kernel in Python with PyTorch data
     import torch
 
-    # Create random input tensors on the GPU
     a = torch.randn(M, K, device="cuda", dtype=torch.float16)
     b = torch.randn(K, N, device="cuda", dtype=torch.float16)
 
diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py
index b6051844d..aa9faf9a9 100644
--- a/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py
+++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py
@@ -8,6 +8,12 @@
 
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
+@tilelang.jit(
+    out_idx=[2],
+    pass_configs={
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        # tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+    })
 def matmul_warp_specialize_copy_1_gemm_0(M,
                                          N,
                                          K,
@@ -61,21 +67,7 @@ def main():
     block_N = 128
     block_K = 64
 
-    # 1. Define the kernel (matmul) and compile/lower it into an executable module
-    func = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K)
-    # print(func.script())
-
-    # 2. Compile the kernel into a torch function
-    # out_idx specifies the index of the output buffer in the argument list
-    # if out_idx is specified, the tensor will be created during runtime
-    # target currently can be "cuda" or "hip" or "cpu".
-    jit_kernel = tilelang.compile(
-        func,
-        out_idx=[2],
-        pass_configs={
-            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-            # tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+    jit_kernel = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K)
     print(jit_kernel.get_kernel_source())
     # 3. Test the kernel in Python with PyTorch data
     import torch
diff --git a/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py
index de7c8d906..629651714 100644
--- a/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py
+++ b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py
@@ -6,6 +6,7 @@
 
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
+@tilelang.jit(out_idx=[2])
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
 
     @T.prim_func
@@ -50,15 +51,10 @@ def main():
     block_M = 128
     block_N = 128
     block_K = 64
-    # 1. Define the kernel (matmul) and compile/lower it into an executable module
-    func = matmul(M, N, K, block_M, block_N, block_K)
-
-    # 2. Compile the kernel into a torch function
-    # out_idx specifies the index of the output buffer in the argument list
-    # if out_idx is specified, the tensor will be created during runtime
-    # target currently can be "cuda" or "hip" or "cpu".
+    jit_kernel = matmul(M, N, K, block_M, block_N, block_K)
+
+    tilelang.disable_cache()
 
-    jit_kernel = tilelang.compile(func, out_idx=[2])
     # 3. Test the kernel in Python with PyTorch data
     import torch