From 586ac70f8640dd481530f915406e9a6e97706c68 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 8 Jul 2025 04:41:27 +0000 Subject: [PATCH 1/7] [Refactor] Update tilelang kernel functions and remove unused imports - Refactored the `flashattn_fwd`, `flashattn_bwd_preprocess`, and `flashattn_bwd_postprocess` functions to utilize direct kernel calls instead of cached versions, improving clarity and performance. - Added `@tilelang.jit` decorators with specified output indices to enhance kernel compilation. - Removed unused import of `cached` from `tilelang`, streamlining the code. - Commented out the main testing function call in `test_tilelang_kernel_mha_bwd.py` for potential future use. --- .../kernel/test_tilelang_kernel_mha_bwd.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/testing/python/kernel/test_tilelang_kernel_mha_bwd.py b/testing/python/kernel/test_tilelang_kernel_mha_bwd.py index 0120e912a..b5c228311 100644 --- a/testing/python/kernel/test_tilelang_kernel_mha_bwd.py +++ b/testing/python/kernel/test_tilelang_kernel_mha_bwd.py @@ -4,7 +4,6 @@ import torch import torch.nn.functional as F import tilelang -from tilelang import cached import tilelang.language as T import tilelang.testing @@ -12,6 +11,9 @@ tilelang.testing.set_random_seed(42) +@tilelang.jit( + out_idx=[3, 4], +) def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N): scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] @@ -80,7 +82,9 @@ def flash_fwd( return flash_fwd - +@tilelang.jit( + out_idx=[2], +) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): dtype = "float16" accum_dtype = "float" @@ -115,7 +119,9 @@ def make_dq_layout(dQ): return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) - +@tilelang.jit( + out_idx=[1], +) def flashattn_bwd_postprocess(batch, heads, seq_len, dim): dtype = "float16" accum_dtype = "float" @@ -138,10 +144,8 @@ 
def flash_bwd_post( @tilelang.jit( - out_idx=[7, 8], - pass_configs={ - tilelang.PassConfigKey.TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS: True, - }) + out_idx=[7, 8] +) def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N): sm_scale = (1.0 / dim)**0.5 scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) @@ -164,10 +168,6 @@ def flash_bwd( with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=32) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype) - # should not store K to local if dim is large - # K_local = T.alloc_fragment([block_M, dim], dtype) - # K_local_T = T.alloc_fragment([block_M, dim], dtype) - # V_local = T.alloc_fragment([block_M, dim], dtype) q = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_M, dim], dtype) qkT = T.alloc_fragment([block_M, block_N], accum_dtype) @@ -185,7 +185,6 @@ def flash_bwd( T.annotate_layout({ dQ: make_dq_layout(dQ), - K_shared: tilelang.layout.make_swizzled_layout(K_shared), dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), }) @@ -240,8 +239,8 @@ def forward(ctx, q, k, v, causal): BATCH, N_CTX, H, D_HEAD = q.shape block_M = 64 block_N = 64 if D_HEAD <= 128 else 32 - mod = cached(flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N), [3, 4]) - o, lse = mod(q, k, v) + kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) + o, lse = kernel(q, k, v) ctx.save_for_backward(q, k, v, o, lse) ctx.causal = causal return o @@ -259,13 +258,13 @@ def maybe_contiguous(x): do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] block_M = 128 block_N = 128 if D_HEAD <= 64 else 32 - mod_prep = cached(flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD), [2]) - mod_post = cached(flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD), [1]) - delta = mod_prep(o, do) - mod = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, 
block_M, block_N) + kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N) + delta = kernel_prep(o, do) dq = torch.zeros_like(q, dtype=torch.float32) - dk, dv = mod(q, k, v, do, lse, delta, dq) - dq = mod_post(dq) + dk, dv = kernel(q, k, v, do, lse, delta, dq) + dq = kernel_post(dq) return dq, dk, dv, None @@ -315,4 +314,5 @@ def test_mha_bwd(): if __name__ == "__main__": - tilelang.testing.main() + # tilelang.testing.main() + assert_mha_equal(8, 32, 256, 64, False) From 3e59572f02ca58f3def8847fe8a8caf12bcc08fd Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 8 Jul 2025 07:09:55 +0000 Subject: [PATCH 2/7] [Refactor] Simplify configuration generation in benchmark and example scripts - Refactored the `get_configs` functions in multiple benchmark and example scripts to utilize a dictionary-based approach for parameter configuration, improving readability and maintainability. - Updated the `flashattn` and `chunk_scan_fwd` functions to directly accept configuration parameters, enhancing flexibility in kernel tuning. - Removed redundant code and streamlined the configuration generation process across various files, ensuring consistency in how configurations are defined and utilized. 
--- benchmark/matmul/benchmark_matmul.py | 42 +- .../matmul/benchmark_matmul_intrinsic.py | 58 +-- benchmark/matmul_fp8/benchmark_matmul.py | 41 +- docs/deeplearning_operators/gemv.md | 4 - .../flash_attention/example_gqa_fwd_bshd.py | 264 +++++----- .../example_gqa_fwd_bshd_wgmma_pipelined.py | 310 ++++++------ .../flash_attention/example_mha_fwd_bhsd.py | 289 +++++------ .../example_mha_fwd_bhsd_wgmma_pipelined.py | 299 ++++++----- .../flash_attention/example_mha_fwd_bshd.py | 250 ++++----- .../example_mha_fwd_bshd_wgmma_pipelined.py | 293 ++++++----- examples/flash_decoding/example_gqa_decode.py | 473 ++++++++---------- examples/gemv/example_gemv.py | 15 +- .../example_mamba_chunk_scan.py | 283 +++++------ .../example_mamba_chunk_state.py | 189 ++++--- .../kernel/test_tilelang_kernel_mha_bwd.py | 18 +- tilelang/autotuner/__init__.py | 33 ++ tilelang/autotuner/param.py | 12 +- 17 files changed, 1361 insertions(+), 1512 deletions(-) diff --git a/benchmark/matmul/benchmark_matmul.py b/benchmark/matmul/benchmark_matmul.py index c0f2c7583..aa98cae12 100644 --- a/benchmark/matmul/benchmark_matmul.py +++ b/benchmark/matmul/benchmark_matmul.py @@ -89,36 +89,18 @@ def get_configs(M, N, K, with_roller=False): for config in configs: print(config) else: - - block_M = [64, 128, 256] - block_N = [64, 128, 256] - block_K = [32, 64] - num_stages = [0, 1, 2, 3] - thread_num = [128, 256] - policy = [T.GemmWarpPolicy.Square] - enable_rasterization = [True, False] - _configs = list( - itertools.product( - block_M, - block_N, - block_K, - num_stages, - thread_num, - policy, - enable_rasterization, - )) - - configs = [ - { - "block_M": c[0], - "block_N": c[1], - "block_K": c[2], - "num_stages": c[3], - "thread_num": c[4], - "policy": c[5], - "enable_rasteration": c[6], # keep param name for backward-compat - } for c in _configs - ] + iter_params = dict( + block_M=[64, 128, 256], + block_N=[64, 128, 256], + block_K=[32, 64], + num_stages=[0, 1, 2, 3], + thread_num=[128, 256], + 
policy=[T.GemmWarpPolicy.Square], + enable_rasteration=[True, False], + ) + return [{ + k: v for k, v in zip(iter_params, values) + } for values in itertools.product(*iter_params.values())] return configs diff --git a/benchmark/matmul/benchmark_matmul_intrinsic.py b/benchmark/matmul/benchmark_matmul_intrinsic.py index 92fcfcfe7..57c164239 100644 --- a/benchmark/matmul/benchmark_matmul_intrinsic.py +++ b/benchmark/matmul/benchmark_matmul_intrinsic.py @@ -221,51 +221,39 @@ def get_configs(M, N, K, with_roller=False): print(config) else: - block_rows_warps = [1, 2, 4] - block_col_warps = [1, 2, 4] - warp_row_tiles = [16, 32, 64, 128] - warp_col_tiles = [16, 32, 64, 128] - chunk = [32, 64, 128, 256] - stage = [0, 2] - enable_rasteration = [True, False] - _configs = list( - itertools.product(block_rows_warps, block_col_warps, warp_row_tiles, warp_col_tiles, - chunk, stage, enable_rasteration)) - configs = [{ - "block_row_warps": c[0], - "block_col_warps": c[1], - "warp_row_tiles": c[2], - "warp_col_tiles": c[3], - "chunk": c[4], - "stage": c[5], - "enable_rasteration": c[6], - } for c in _configs] + iter_params = dict( + block_row_warps=[1, 2, 4], + block_col_warps=[1, 2, 4], + warp_row_tiles=[16, 32, 64, 128], + warp_col_tiles=[16, 32, 64, 128], + chunk=[32, 64, 128, 256], + stage=[0, 2], + enable_rasteration=[True, False], + ) + return [{ + k: v for k, v in zip(iter_params, values) + } for values in itertools.product(*iter_params.values())] return configs -def matmul(M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_roller=False): +def matmul( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_roller=False, +): """Create an autotuned tensor core matrix multiplication kernel.""" @autotune( configs=get_configs(M, N, K, with_roller), - keys=[ - "block_row_warps", - "block_col_warps", - "warp_row_tiles", - "warp_col_tiles", - "chunk", - "enable_rasteration", - "stage", - ], warmup=3, rep=5, + 
ref_prog=ref_program, + skip_check=True, ) @tl.jit(out_idx=[2],) def kernel( diff --git a/benchmark/matmul_fp8/benchmark_matmul.py b/benchmark/matmul_fp8/benchmark_matmul.py index 4971ddd44..8235ef1cd 100644 --- a/benchmark/matmul_fp8/benchmark_matmul.py +++ b/benchmark/matmul_fp8/benchmark_matmul.py @@ -90,36 +90,19 @@ def get_configs(M, N, K, with_roller=False): for config in configs: print(config) else: + iter_params = dict( + block_M=[64, 128, 256], + block_N=[64, 128, 256], + block_K=[64, 128], + num_stages=[0, 1, 2, 3], + thread_num=[128, 256], + policy=[T.GemmWarpPolicy.Square], + enable_rasteration=[True, False], + ) + return [{ + k: v for k, v in zip(iter_params, values) + } for values in itertools.product(*iter_params.values())] - block_M = [64, 128, 256] - block_N = [64, 128, 256] - block_K = [64, 128] - num_stages = [0, 1, 2, 3] - thread_num = [128, 256] - policy = [T.GemmWarpPolicy.Square] - enable_rasterization = [True, False] - _configs = list( - itertools.product( - block_M, - block_N, - block_K, - num_stages, - thread_num, - policy, - enable_rasterization, - )) - - configs = [ - { - "block_M": c[0], - "block_N": c[1], - "block_K": c[2], - "num_stages": c[3], - "thread_num": c[4], - "policy": c[5], - "enable_rasteration": c[6], # keep param name for backward-compat - } for c in _configs - ] return configs diff --git a/docs/deeplearning_operators/gemv.md b/docs/deeplearning_operators/gemv.md index fe949841c..0ceafe7ed 100644 --- a/docs/deeplearning_operators/gemv.md +++ b/docs/deeplearning_operators/gemv.md @@ -337,10 +337,6 @@ def get_best_config(N, K): @autotune( configs=get_configs(), - keys=[ - "BLOCK_N", - "reduce_threads", - ], warmup=3, rep=20, ) diff --git a/examples/flash_attention/example_gqa_fwd_bshd.py b/examples/flash_attention/example_gqa_fwd_bshd.py index 4b7d70ff3..b9dee8fcb 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd.py +++ b/examples/flash_attention/example_gqa_fwd_bshd.py @@ -60,7 +60,18 @@ def 
get_configs(user_config=None): return valid_configs -def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[3]) +def flashattn(batch, + heads, + seq_len, + dim, + is_causal, + groups=1, + block_M=64, + block_N=64, + num_stages=0, + threads=128): scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] @@ -68,142 +79,119 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): dtype = "float16" accum_dtype = "float" - @tilelang.jit(out_idx=[3]) - def kernel_func(block_M, block_N, num_stages, threads): + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, + -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), + @T.macro + def Softmax( acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - 
by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_M, dim], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = 
T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - - @T.prim_func - def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - ): - with T.Kernel( - T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_M, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim], dtype) - O_shared = T.alloc_shared([block_M, dim], dtype) - acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) - acc_o = T.alloc_fragment([block_M, 
dim], accum_dtype) - scores_max = T.alloc_fragment([block_M], accum_dtype) - scores_max_prev = T.alloc_fragment([block_M], accum_dtype) - scores_scale = T.alloc_fragment([block_M], accum_dtype) - scores_sum = T.alloc_fragment([block_M], accum_dtype) - logsum = T.alloc_fragment([block_M], accum_dtype) - - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) - - for k in T.Pipelined(loop_range, num_stages=num_stages): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] /= logsum[i] - T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) - - return main + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv( + (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) - if tune: - - @autotune( - configs=get_configs(), - keys=["block_M", "block_N", "num_stages", "threads"], - warmup=10, - rep=10) - @tilelang.jit(out_idx=[3]) - def kernel(block_M=None, block_N=None, num_stages=None, threads=None): - return kernel_func(block_M, block_N, num_stages, threads) - - return kernel() - else: - - def kernel(block_M, block_N, num_stages, threads): - return kernel_func(block_M, block_N, num_stages, threads) + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, + logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= 
logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) - return kernel + return main def ref_program(Q, K, V, is_causal, groups=1): @@ -245,8 +233,16 @@ def main(batch: int = 1, if (not tune): kernel = flashattn( - batch, heads, seq_len, dim, is_causal, tune=tune, groups=groups)( - block_M=64, block_N=64, num_stages=2, threads=128) + batch, + heads, + seq_len, + dim, + is_causal, + groups=groups, + block_M=64, + block_N=64, + num_stages=2, + threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -258,10 +254,10 @@ def main(batch: int = 1, print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) else: - best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=tune) - best_latency = best_result.latency - best_config = best_result.config - ref_latency = best_result.ref_latency + kernel = flashattn(batch, heads, seq_len, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency print(f"Best latency: {best_latency}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best config: {best_config}") diff --git a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py index 66309532b..abd8e05e9 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py @@ -12,22 +12,33 @@ def get_configs(): - block_M = [128] - block_N = [128] - num_stages = [2] - threads = [256] - _configs = list(itertools.product(block_M, block_N, num_stages, threads)) - - configs = [{ - 'block_M': c[0], - 'block_N': c[1], - 'num_stages': c[2], - 'threads': c[3] 
- } for c in _configs] - return configs - - -def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): + iter_params = dict( + block_M=[128], + block_N=[128], + num_stages=[2], + threads=[256], + ) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune( + configs=get_configs(), + warmup=10, + rep=10, +) +@tilelang.jit(out_idx=[3]) +def flashattn( + batch, + heads, + seq_len, + dim, + is_causal, + groups=1, + block_M=64, + block_N=64, + num_stages=0, + threads=128, +): scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] @@ -35,147 +46,124 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): dtype = "float16" accum_dtype = "float" - @tilelang.jit(out_idx=[3]) - def kernel_func(block_M, block_N, num_stages, threads): - - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, + -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, k * block_N:(k + 1) * block_N, by // 
groups, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_M, dim], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = 
T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - - @T.prim_func - def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - ): - with T.Kernel( - T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_M, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim], dtype) - O_shared = T.alloc_shared([block_M, dim], dtype) - acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) - acc_o = T.alloc_fragment([block_M, 
dim], accum_dtype) - scores_max = T.alloc_fragment([block_M], accum_dtype) - scores_max_prev = T.alloc_fragment([block_M], accum_dtype) - scores_scale = T.alloc_fragment([block_M], accum_dtype) - scores_sum = T.alloc_fragment([block_M], accum_dtype) - logsum = T.alloc_fragment([block_M], accum_dtype) - - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) - - for k in T.Pipelined( - loop_range, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] /= logsum[i] - T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) - - return main - - if tune: - - @autotune( - configs=get_configs(), - keys=["block_M", "block_N", "num_stages", "threads"], - warmup=10, - rep=10) - @tilelang.jit(out_idx=[3]) - def kernel(block_M=None, block_N=None, num_stages=None, threads=None): - return kernel_func(block_M, block_N, num_stages, threads) - - return kernel() - else: - def kernel(block_M, block_N, num_stages, threads): - return kernel_func(block_M, block_N, num_stages, threads) + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv( + (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + + for k in T.Pipelined( + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): + MMA0(K, Q_shared, 
K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, + logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) - return kernel + return main def ref_program(Q, K, V, is_causal, groups=1): @@ -219,8 +207,16 @@ def main( if (not tune): kernel = flashattn( - batch, heads, seq_len, dim, is_causal, tune=tune, groups=groups)( - block_M=128, block_N=128, num_stages=2, threads=256) + batch, + heads, + seq_len, + dim, + is_causal, + groups=groups, + block_M=128, + block_N=128, + num_stages=2, + threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -232,10 +228,10 @@ def main( print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) else: - best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=tune) - best_latency = best_result.latency - best_config = best_result.config - ref_latency = best_result.ref_latency + kernel = flashattn(batch, heads, seq_len, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency print(f"Best latency: {best_latency}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best config: {best_config}") diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py index bbd7abc20..53a4ecfb5 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd.py +++ b/examples/flash_attention/example_mha_fwd_bhsd.py @@ -9,166 +9,147 @@ import itertools import argparse from functools import partial -from tilelang import jit 
def get_configs(): - block_M = [128] - block_N = [128] - num_stages = [2] - threads = [256] - _configs = list(itertools.product(block_M, block_N, num_stages, threads)) - - configs = [{ - 'block_M': c[0], - 'block_N': c[1], - 'num_stages': c[2], - 'threads': c[3] - } for c in _configs] - return configs - - -def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[3]) +def flashattn(batch, + heads, + seq_q, + seq_kv, + dim, + is_causal, + block_M=64, + block_N=64, + num_stages=1, + threads=128): scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] dtype = "float16" accum_dtype = "float" - @tilelang.jit(out_idx=[3]) - def kernel_func(block_M, block_N, num_stages, threads): - - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + past_len = seq_kv - seq_q + T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: 
T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - past_len = seq_kv - seq_q - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - q_idx = bx * block_M + i + past_len - k_idx = k * block_N + j - acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_M, dim], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. 
+ # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = 
T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. 
- acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - - @T.prim_func - def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - ): - with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_M, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim], dtype) - O_shared = T.alloc_shared([block_M, dim], dtype) - acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) - acc_o = T.alloc_fragment([block_M, dim], accum_dtype) - scores_max = T.alloc_fragment([block_M], accum_dtype) - scores_max_prev = T.alloc_fragment([block_M], accum_dtype) - scores_scale = T.alloc_fragment([block_M], accum_dtype) - scores_sum = T.alloc_fragment([block_M], accum_dtype) - logsum = T.alloc_fragment([block_M], accum_dtype) - - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - loop_range = ( - T.min(T.ceildiv(seq_kv, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) - - for k in T.Pipelined(loop_range, num_stages=num_stages): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) - for i, j in T.Parallel(block_M, dim): - 
acc_o[i, j] /= logsum[i] - T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) - - return main - - if tune: - - @autotune(configs=get_configs(), warmup=10, rep=10) - @jit(out_idx=[3]) - def kernel(block_M=None, block_N=None, num_stages=None, threads=None): - return kernel_func(block_M, block_N, num_stages, threads) - - return kernel() - else: + loop_range = ( + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv( + (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) - def kernel(block_M, block_N, num_stages, threads): - return kernel_func(block_M, block_N, num_stages, threads) + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, + logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) - return kernel + return main def ref_program(Q, K, V, is_causal): @@ -202,8 +183,16 @@ def main( if (not tune): kernel = flashattn( - batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)( - block_M=64, block_N=64, num_stages=1, threads=128) + batch, + heads, + seq_q, + seq_kv, + dim, + is_causal, + block_M=64, + block_N=64, + num_stages=1, + threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler() @@ -216,10 +205,10 @@ def main( print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) else: - best_result = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune) - best_latency = best_result.latency - best_config = best_result.config - ref_latency = best_result.ref_latency + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal) + best_latency = kernel.latency 
+ best_config = kernel.config + ref_latency = kernel.ref_latency print(f"Best latency: {best_latency}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best config: {best_config}") diff --git a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py index e11e855bd..a6a93d914 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -9,171 +9,152 @@ import itertools import argparse from functools import partial -from tilelang import jit def get_configs(): - block_M = [128] - block_N = [128] - num_stages = [2] - threads = [256] - _configs = list(itertools.product(block_M, block_N, num_stages, threads)) - - configs = [{ - 'block_M': c[0], - 'block_N': c[1], - 'num_stages': c[2], - 'threads': c[3] - } for c in _configs] - return configs - - -def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[3]) +def flashattn(batch, + heads, + seq_q, + seq_kv, + dim, + is_causal, + block_M=128, + block_N=128, + num_stages=2, + threads=256): scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] dtype = "float16" accum_dtype = "float" - @tilelang.jit(out_idx=[3]) - def kernel_func(block_M, block_N, num_stages, threads): - - @T.macro - def MMA0( - K: T.Tensor(kv_shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: 
T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + past_len = seq_kv - seq_q + T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - past_len = seq_kv - seq_q - T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - q_idx = bx * block_M + i + past_len - k_idx = k * block_N + j - acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(kv_shape, dtype), - V_shared: T.SharedBuffer([block_M, dim], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, 
-T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + 
Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. 
- acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): + loop_range = ( + T.min(T.ceildiv(seq_kv, block_N), T.ceildiv( + (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) + + for k in T.Pipelined( + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, + logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - - @T.prim_func - def main( - Q: T.Tensor(q_shape, dtype), - K: T.Tensor(kv_shape, dtype), - V: T.Tensor(kv_shape, dtype), - Output: T.Tensor(q_shape, dtype), - ): - with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_M, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim], dtype) - O_shared = T.alloc_shared([block_M, dim], dtype) - acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) - acc_o = T.alloc_fragment([block_M, dim], accum_dtype) - scores_max = T.alloc_fragment([block_M], accum_dtype) - scores_max_prev = T.alloc_fragment([block_M], accum_dtype) - scores_scale = T.alloc_fragment([block_M], accum_dtype) - scores_sum = T.alloc_fragment([block_M], accum_dtype) - logsum = T.alloc_fragment([block_M], accum_dtype) - - T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) - 
T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - loop_range = ( - T.min(T.ceildiv(seq_kv, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) - - for k in T.Pipelined( - loop_range, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] /= logsum[i] - T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) - - return main - - if tune: - - @autotune(configs=get_configs(), warmup=10, rep=10) - @jit(out_idx=[3]) - def kernel(block_M=None, block_N=None, num_stages=None, threads=None): - return kernel_func(block_M, block_N, num_stages, threads) - - return kernel() - else: - - def kernel(block_M, block_N, num_stages, threads): - return kernel_func(block_M, block_N, num_stages, threads) + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) - return kernel + return main def ref_program(Q, K, V, is_causal): @@ -207,8 +188,16 @@ def main( if (not tune): kernel = flashattn( - batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)( - block_M=128, block_N=128, num_stages=2, threads=256) + batch, + heads, + seq_q, + seq_kv, + dim, + is_causal, + block_M=128, + block_N=128, + num_stages=2, + threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler() @@ -221,10 +210,10 @@ def main( print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) else: - best_result = flashattn(batch, heads, seq_q, seq_kv, dim, 
is_causal, tune=tune) - best_latency = best_result.latency - best_config = best_result.config - ref_latency = best_result.ref_latency + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency print(f"Best latency: {best_latency}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best config: {best_config}") diff --git a/examples/flash_attention/example_mha_fwd_bshd.py b/examples/flash_attention/example_mha_fwd_bshd.py index 60765c0c2..aaf711559 100644 --- a/examples/flash_attention/example_mha_fwd_bshd.py +++ b/examples/flash_attention/example_mha_fwd_bshd.py @@ -12,159 +12,131 @@ def get_configs(): - block_M = [64] - block_N = [64] - num_stages = [1] - threads = [128] - _configs = list(itertools.product(block_M, block_N, num_stages, threads)) - - configs = [{ - 'block_M': c[0], - 'block_N': c[1], - 'num_stages': c[2], - 'threads': c[3] - } for c in _configs] - return configs + iter_params = dict(block_M=[64], block_N=[64], num_stages=[1], threads=[128]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[3]) def flashattn(batch, heads, seq_len, dim, is_causal, tune=False): scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" - @tilelang.jit(out_idx=[3]) - def kernel_func(block_M, block_N, num_stages, threads): - - @T.macro - def MMA0( - K: T.Tensor(shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), + @T.macro + def MMA0( + K: T.Tensor(shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k 
* block_N:(k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, + -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(shape, dtype), - V_shared: T.SharedBuffer([block_M, dim], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, 
kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared 
= T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. 
- acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - - @T.prim_func - def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - Output: T.Tensor(shape, dtype), - ): - with T.Kernel( - T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_M, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim], dtype) - O_shared = T.alloc_shared([block_M, dim], dtype) - acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) - acc_o = T.alloc_fragment([block_M, dim], accum_dtype) - scores_max = T.alloc_fragment([block_M], accum_dtype) - scores_max_prev = T.alloc_fragment([block_M], accum_dtype) - scores_scale = T.alloc_fragment([block_M], accum_dtype) - scores_sum = T.alloc_fragment([block_M], accum_dtype) - logsum = T.alloc_fragment([block_M], accum_dtype) - - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) - - for k in T.Pipelined(loop_range, num_stages=num_stages): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) - for i, j in T.Parallel(block_M, dim): - acc_o[i, 
j] /= logsum[i] - T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) - - return main - - if tune: - - @autotune(configs=get_configs(), warmup=10, rep=10) - @tilelang.jit(out_idx=[3]) - def kernel(block_M=None, block_N=None, num_stages=None, threads=None): - return kernel_func(block_M, block_N, num_stages, threads) - - return kernel() - else: - def kernel(block_M, block_N, num_stages, threads): - return kernel_func(block_M, block_N, num_stages, threads) + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv( + (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, + logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) - return kernel + return main def ref_program(Q, K, V, is_causal): diff --git a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py index 5b0d35f89..62da012b1 100644 --- a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py @@ -12,164 +12,144 @@ def get_configs(): - block_M = [128] - block_N = [128] - num_stages = [2] - threads = [256] - _configs = list(itertools.product(block_M, block_N, num_stages, threads)) - - configs = [{ - 'block_M': c[0], - 'block_N': c[1], - 'num_stages': c[2], - 'threads': c[3] - } for c in _configs] - return configs - - -def flashattn(batch, heads, seq_len, dim, is_causal, tune=False): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256]) + return [dict(zip(iter_params, 
values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[3]) +def flashattn(batch, + heads, + seq_len, + dim, + is_causal, + block_M=128, + block_N=128, + num_stages=2, + threads=256): scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" - @tilelang.jit(out_idx=[3]) - def kernel_func(block_M, block_N, num_stages, threads): - - @T.macro - def MMA0( - K: T.Tensor(shape, dtype), - Q_shared: T.SharedBuffer([block_M, dim], dtype), - K_shared: T.SharedBuffer([block_N, dim], dtype), + @T.macro + def MMA0( + K: T.Tensor(shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, + -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(shape, dtype), + V_shared: T.SharedBuffer([block_M, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - k: T.int32, - bx: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + 
j, 0, - -T.infinity(acc_s.dtype)) - else: - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def MMA1( - V: T.Tensor(shape, dtype), - V_shared: T.SharedBuffer([block_M, dim], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - k: T.int32, - by: T.int32, - bz: T.int32, - ): - T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - @T.macro - def Softmax( - acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), - acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), - scores_max: T.FragmentBuffer([block_M], accum_dtype), - scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - scores_sum: T.FragmentBuffer([block_M], accum_dtype), - logsum: T.FragmentBuffer([block_M], accum_dtype), - ): - T.copy(scores_max, scores_max_prev) + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], 
accum_dtype) + + T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # To do causal softmax, we need to set the scores_max to 0 if it is -inf - # This process is called Check_inf in FlashAttention3 code, and it only need to be done - # in the first ceil_div(kBlockM, kBlockN) steps. - # for i in T.Parallel(block_M): - # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - # max * log_2(e)) This allows the compiler to use the ffma - # instruction instead of fadd and fmul separately. - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - @T.macro - def Rescale( - acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), - scores_scale: T.FragmentBuffer([block_M], accum_dtype), - ): - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - - @T.prim_func - def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - Output: T.Tensor(shape, dtype), - ): - with T.Kernel( - T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_M, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim], dtype) - O_shared = T.alloc_shared([block_M, dim], dtype) - acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) - acc_o = T.alloc_fragment([block_M, dim], accum_dtype) - scores_max = 
T.alloc_fragment([block_M], accum_dtype) - scores_max_prev = T.alloc_fragment([block_M], accum_dtype) - scores_scale = T.alloc_fragment([block_M], accum_dtype) - scores_sum = T.alloc_fragment([block_M], accum_dtype) - logsum = T.alloc_fragment([block_M], accum_dtype) - - T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv( - (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) - - for k in T.Pipelined( - loop_range, - num_stages=num_stages, - order=[-1, 0, 3, 1, -1, 2], - stage=[-1, 0, 0, 1, -1, 1], - group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): - MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) - Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, - scores_sum, logsum) - Rescale(acc_o, scores_scale) - MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] /= logsum[i] - T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) - - return main - - if tune: - - @autotune(configs=get_configs(), warmup=10, rep=10) - @tilelang.jit(out_idx=[3]) - def kernel(block_M=None, block_N=None, num_stages=None, threads=None): - return kernel_func(block_M, block_N, num_stages, threads) - - return kernel() - else: - def kernel(block_M, block_N, num_stages, threads): - return kernel_func(block_M, block_N, num_stages, threads) + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv( + (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + + for k in T.Pipelined( + loop_range, + num_stages=num_stages, + order=[-1, 0, 3, 1, -1, 2], + stage=[-1, 0, 0, 1, -1, 1], + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, 
scores_scale, scores_sum, + logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) - return kernel + return main def ref_program(Q, K, V, is_causal): @@ -201,8 +181,15 @@ def main( if (not tune): kernel = flashattn( - batch, heads, seq_len, dim, is_causal, tune=tune)( - block_M=128, block_N=128, num_stages=2, threads=256) + batch, + heads, + seq_len, + dim, + is_causal, + block_M=128, + block_N=128, + num_stages=2, + threads=256) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) @@ -214,10 +201,10 @@ def main( print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) else: - best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=tune) - best_latency = best_result.latency - best_config = best_result.config - ref_latency = best_result.ref_latency + kernel = flashattn(batch, heads, seq_len, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency print(f"Best latency: {best_latency}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best config: {best_config}") diff --git a/examples/flash_decoding/example_gqa_decode.py b/examples/flash_decoding/example_gqa_decode.py index cc5ed7358..649a1ab84 100644 --- a/examples/flash_decoding/example_gqa_decode.py +++ b/examples/flash_decoding/example_gqa_decode.py @@ -54,7 +54,10 @@ def get_pass_configs(): return {} -def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False): +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs()) +def flashattn(batch, heads, 
groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, + threads): scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) shape_q = [batch, heads, dim] shape_k = [batch, seqlen_kv, groups, dim] @@ -64,260 +67,224 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False): accum_dtype = "float" kv_group_num = heads // groups - @tilelang.jit(out_idx=[6], pass_configs=get_pass_configs()) - def kernel_func(block_N, block_H, num_split, num_stages, threads): - part_shape = [batch, heads, num_split, dim] - valid_block_H = min(block_H, kv_group_num) - valid_block_N = min(block_N, seqlen_kv // num_split) - - @T.macro - def flash_attn( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - Output: T.Tensor([batch, heads, dim], dtype), - ): - with T.Kernel( - batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_H, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim], dtype) - O_shared = T.alloc_shared([valid_block_H, dim], dtype) - acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) - mask_local = T.alloc_fragment([block_N], "uint8") - acc_o = T.alloc_fragment([block_H, dim], accum_dtype) - scores_max = T.alloc_fragment([block_H], accum_dtype) - scores_max_prev = T.alloc_fragment([block_H], accum_dtype) - scores_scale = T.alloc_fragment([block_H], accum_dtype) - scores_sum = T.alloc_fragment([block_H], accum_dtype) - logsum = T.alloc_fragment([block_H], accum_dtype) - - bid = bx - hid = by - cur_kv_head = hid // (kv_group_num // valid_block_H) - - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) + part_shape = [batch, heads, num_split, dim] + valid_block_H = min(block_H, kv_group_num) + valid_block_N = min(block_N, seqlen_kv // 
num_split) + + @T.macro + def flash_attn( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + Output: T.Tensor([batch, heads, dim], dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([valid_block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + mask_local = T.alloc_fragment([block_N], "uint8") + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + bid = bx + hid = by + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared) + T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], + -T.infinity(accum_dtype)) + T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) - - loop_range = T.ceildiv((seqlen_kv // num_split), block_N) - for k in T.Pipelined(loop_range, 
num_stages=num_stages): - T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared) - T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local) - T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], - -T.infinity(accum_dtype)) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_H): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_H): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - for i, j in T.Parallel(block_H, dim): - acc_o[i, j] *= scores_scale[i] - T.copy(V[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - for i, j in T.Parallel(block_H, dim): - acc_o[i, j] /= logsum[i] + T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) - - @T.macro - def flash_attn_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - ): - with T.Kernel( - batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_H, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = 
T.alloc_shared([block_N, dim], dtype) - O_shared = T.alloc_shared([valid_block_H, dim], dtype) - acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) - mask_local = T.alloc_fragment([block_N], "uint8") - acc_o = T.alloc_fragment([block_H, dim], accum_dtype) - scores_max = T.alloc_fragment([block_H], accum_dtype) - scores_max_prev = T.alloc_fragment([block_H], accum_dtype) - scores_scale = T.alloc_fragment([block_H], accum_dtype) - scores_sum = T.alloc_fragment([block_H], accum_dtype) - logsum = T.alloc_fragment([block_H], accum_dtype) - - bid = bx - hid = by - sid = bz - cur_kv_head = hid // (kv_group_num // valid_block_H) - - T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - loop_range = T.ceildiv((seqlen_kv // num_split), block_N) - T.fill(K_shared, 0) - for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy( - K[bid, (seqlen_kv // num_split) * sid + - k * valid_block_N:(seqlen_kv // num_split) * sid + - (k + 1) * valid_block_N, cur_kv_head, :], K_shared) - T.copy( - mask[bid, (seqlen_kv // num_split) * sid + - k * valid_block_N:(seqlen_kv // num_split) * sid + - (k + 1) * valid_block_N, cur_kv_head], mask_local) - T.clear(acc_s) - T.gemm( - Q_shared, - K_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullRow) - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else( - (mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j], - -T.infinity(accum_dtype)) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_H): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, 
scores_sum, dim=1) - for i in T.Parallel(block_H): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - for i, j in T.Parallel(block_H, dim): - acc_o[i, j] *= scores_scale[i] - T.copy( - V[bid, (seqlen_kv // num_split) * sid + - k * valid_block_N:(seqlen_kv // num_split) * sid + - (k + 1) * valid_block_N, cur_kv_head, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) for i, j in T.Parallel(block_H, dim): - acc_o[i, j] /= logsum[i] + acc_o[i, j] *= scores_scale[i] + T.copy(V[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(acc_o[:valid_block_H, :], O_shared) + T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) + + @T.macro + def flash_attn_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([valid_block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = 
T.alloc_fragment([block_H, block_N], dtype) + mask_local = T.alloc_fragment([block_N], "uint8") + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + T.fill(K_shared, 0) + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy( + K[bid, (seqlen_kv // num_split) * sid + + k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, :], K_shared) + T.copy( + mask[bid, (seqlen_kv // num_split) * sid + + k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head], mask_local) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, + j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), + acc_s[i, j], -T.infinity(accum_dtype)) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) for i in T.Parallel(block_H): - if i < valid_block_H: - glse[bid, hid * valid_block_H + i, sid] = logsum[i] 
- T.copy(acc_o[:valid_block_H, :], O_shared) - T.copy(O_shared, Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H, - sid, :]) - - @T.macro - def combine( - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_o, dtype), - ): - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim], dtype) - o_accum_local = T.alloc_fragment([dim], accum_dtype) - lse_local = T.alloc_fragment([num_split, 128], dtype) - lse_local_split = T.alloc_local([1], accum_dtype) - lse_logsum_local = T.alloc_local([1], accum_dtype) - lse_max_local = T.alloc_fragment([128], accum_dtype) - scale_local = T.alloc_local([1], accum_dtype) - - T.annotate_layout({ - lse_logsum_local: - T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), - lse_max_local: - T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i), - # lse_local: (local_id, thread_id) - lse_local: - T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)), - }) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - for k, j in T.Parallel(num_split, 128): - lse_local[k, j] = glse[bz, by, k] - T.reduce_max(lse_local, lse_max_local, dim=0, clear=True) - for k in T.Pipelined(num_split, num_stages=1): - lse_local_split[0] = glse[bz, by, k] - lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) - lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] - for k in T.serial(num_split): - for i in T.Parallel(dim): - po_local[i] = Output_partial[bz, by, k, i] - lse_local_split[0] = glse[bz, by, k] - scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) - for i in T.Parallel(dim): - o_accum_local[i] += po_local[i] * scale_local[0] + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.copy( + V[bid, (seqlen_kv // num_split) * sid + + k * valid_block_N:(seqlen_kv 
// num_split) * sid + (k + 1) * valid_block_N, + cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + T.copy(acc_o[:valid_block_H, :], O_shared) + T.copy(O_shared, Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H, + sid, :]) + + @T.macro + def combine( + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local = T.alloc_fragment([num_split, 128], dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_fragment([128], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout({ + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i), + # lse_local: (local_id, thread_id) + lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)), + }) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + for k, j in T.Parallel(num_split, 128): + lse_local[k, j] = glse[bz, by, k] + T.reduce_max(lse_local, lse_max_local, dim=0, clear=True) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bz, by, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split[0] 
= glse[bz, by, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) for i in T.Parallel(dim): - Output[bz, by, i] = o_accum_local[i] - - @T.prim_func - def flashattn_gqa_decode_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_o, dtype), - ): - flash_attn_split(Q, K, V, mask, glse, Output_partial) - combine(glse, Output_partial, Output) - - @T.prim_func - def flashattn_gqa_decode_no_split( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), - Output: T.Tensor(shape_o, dtype), - ): - flash_attn(Q, K, V, mask, Output) - - if num_split > 1: - return flashattn_gqa_decode_split - else: - return flashattn_gqa_decode_no_split - - if tune: - - @autotune(configs=get_configs(), warmup=10, rep=10) - @jit( - out_idx=[6], - supply_type=tilelang.TensorSupplyType.Auto, - ref_prog=ref_program, - max_mismatched_ratio=0.05) - def kernel(block_N=None, block_H=None, num_split=None, num_stages=None, threads=None): - return kernel_func(block_N, block_H, num_split, num_stages, threads) - - return kernel() + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def flashattn_gqa_decode_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), + ): + flash_attn_split(Q, K, V, mask, glse, Output_partial) + combine(glse, Output_partial, Output) + + 
@T.prim_func + def flashattn_gqa_decode_no_split( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), + glse: T.Tensor([batch, heads, num_split], dtype), + Output_partial: T.Tensor(part_shape, dtype), + Output: T.Tensor(shape_o, dtype), + ): + flash_attn(Q, K, V, mask, Output) + + if num_split > 1: + return flashattn_gqa_decode_split else: - - def kernel(block_N, block_H, num_split, num_stages, threads): - return kernel_func(block_N, block_H, num_split, num_stages, threads) - - return kernel + return flashattn_gqa_decode_no_split def ref_program(query, key, value, mask, glse, Output_partial): @@ -485,7 +452,7 @@ def main(batch: int = 1, if (not tune): config, sm_version = get_heuristic_config() - kernel = flashattn(batch, heads, groups, kv_seqlen, dim, tune=tune)(**config) + kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16) @@ -513,10 +480,10 @@ def main(batch: int = 1, print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) else: - best_result = flashattn(batch, heads, groups, kv_seqlen, dim, tune=tune) - best_latency = best_result.latency - best_config = best_result.config - ref_latency = best_result.ref_latency + kernel = flashattn(batch, heads, groups, kv_seqlen, dim) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency print(f"Best latency: {best_latency}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best config: {best_config}") diff --git a/examples/gemv/example_gemv.py b/examples/gemv/example_gemv.py index 4b06e2055..b532620d4 100644 --- a/examples/gemv/example_gemv.py +++ b/examples/gemv/example_gemv.py @@ -221,17 +221,10 @@ def main( def get_best_config(N, K): def 
get_configs(): - BLOCK_N = [2, 4, 8, 32, 64, 128] - reduce_threads = [4, 8, 32] - _configs = list(itertools.product( - BLOCK_N, - reduce_threads, - )) - configs = [{ - "BLOCK_N": c[0], - "reduce_threads": c[1], - } for c in _configs] - return configs + iter_params = dict(BLOCK_N=[2, 4, 8, 32, 64, 128], reduce_threads=[4, 8, 32]) + return [ + dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values()) + ] @autotune( configs=get_configs(), diff --git a/examples/linear_attention/example_mamba_chunk_scan.py b/examples/linear_attention/example_mamba_chunk_scan.py index 1bc53d767..515298fef 100644 --- a/examples/linear_attention/example_mamba_chunk_scan.py +++ b/examples/linear_attention/example_mamba_chunk_scan.py @@ -64,158 +64,139 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): def get_configs(): - block_M = [64, 128, 256] - block_N = [32, 64] - block_K = [64, 128, 256] - block_Dstate = [128] - num_stages = [1, 2, 3, 4, 5] - _configs = list(itertools.product(block_M, block_N, block_K, block_Dstate, num_stages)) - - configs = [{ - 'block_M': c[0], - 'block_N': c[1], - 'block_K': c[2], - 'block_Dstate': c[3], - 'num_stages': c[4], - 'threads': c[0] * 2 - } for c in _configs] - return configs + iter_params = dict( + block_M=[64, 128, 256], + block_N=[32, 64], + block_K=[64, 128, 256], + block_Dstate=[128], + num_stages=[1, 2, 3, 4, 5]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] +@autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[7]) -def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, tune=False): +def chunk_scan_fwd(batch, + seqlen, + chunk_size, + ngroups, + nheads, + headdim, + dstate, + block_M=64, + block_N=64, + block_K=64, + block_Dstate=128, + num_stages=2, + threads=128): dtype = "float16" accum_dtype = "float" nchunks = T.ceildiv(seqlen, chunk_size) p = 1.44269504 - def kernel_func(block_M, block_N, block_K, 
block_Dstate, num_stages, threads): - - @T.prim_func - def main(cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), - x: T.Tensor((batch, seqlen, nheads, headdim), dtype), dt: T.Tensor( - (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor( - (batch, nheads, nchunks, chunk_size), dtype), C: T.Tensor( - (batch, seqlen, ngroups, dstate), dtype), prev_states: T.Tensor( - (batch, nchunks, nheads, headdim, dstate), dtype), D: T.Tensor( - (nheads), dtype), Output: T.Tensor( - (batch, seqlen, nheads, headdim), dtype)): - with T.Kernel( - nheads, - T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), - batch * nchunks, - threads=threads) as (bz, bx, by): - acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) - acc_o_shared = T.alloc_shared((block_M, block_N), dtype) - cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn") - cb_local = T.alloc_fragment((block_M, block_K), dtype) - dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared") - dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype) - dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype) - dt_shared = T.alloc_shared((block_K), dtype, scope="shared") - dt_local = T.alloc_fragment((block_K), accum_dtype) - x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn") - dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared") - scale_m_local = T.alloc_fragment((block_M), accum_dtype) - C_shared = T.alloc_shared((block_M, block_Dstate), dtype) - prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype) - D_local = T.alloc_fragment((1), accum_dtype) - x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn") - x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - batch_idx = by % batch - chunk_idx = by // batch - # m: chunk_size - # n : headdim - m_idx = bx // T.ceildiv(headdim, block_N) - n_idx = bx % T.ceildiv(headdim, block_N) - - T.annotate_layout({ - acc_o_shared: 
tilelang.layout.make_swizzled_layout(acc_o_shared), - cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), - x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared) - }) - - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M], - dA_cs_m_shared) - T.copy(dA_cs_m_shared, dA_cs_m_local) - T.clear(acc_o) - - for i in T.Parallel(block_M): - scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) - T.copy( - C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared) - T.copy( - prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N, - 0:block_Dstate], prev_state_shared) - T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) - for i, j in T.Parallel(block_M, block_N): - acc_o[i, j] *= scale_m_local[i] - - loop_range = T.ceildiv((m_idx + 1) * block_M, block_K) - - for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy( - cb[batch_idx, chunk_idx, bz // (nheads // ngroups), - m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K], - cb_shared) - T.copy(cb_shared, cb_local) - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], - dA_cs_k_shared) - T.copy(dA_cs_k_shared, dA_cs_k_local) - for i, j in T.Parallel(block_M, block_K): - cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - - dA_cs_k_local[j] * p) - T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) - T.copy(dt_shared, dt_local) - for i, j in T.Parallel(block_M, block_K): - cb_local[i, j] *= dt_local[j] - for i, j in T.Parallel(block_M, block_K): - cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, - cb_local[i, j], 0) - T.copy( - x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + - (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared) - T.gemm(cb_local, x_shared, acc_o) - - D_local[0] = D[bz] + 
@T.prim_func + def main(cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), x: T.Tensor( + (batch, seqlen, nheads, headdim), dtype), dt: T.Tensor( + (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor( + (batch, nheads, nchunks, chunk_size), dtype), + C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), prev_states: T.Tensor( + (batch, nchunks, nheads, headdim, dstate), dtype), D: T.Tensor( + (nheads), dtype), Output: T.Tensor((batch, seqlen, nheads, headdim), dtype)): + with T.Kernel( + nheads, + T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), + batch * nchunks, + threads=threads) as (bz, bx, by): + acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) + acc_o_shared = T.alloc_shared((block_M, block_N), dtype) + cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn") + cb_local = T.alloc_fragment((block_M, block_K), dtype) + dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared") + dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype) + dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype) + dt_shared = T.alloc_shared((block_K), dtype, scope="shared") + dt_local = T.alloc_fragment((block_K), accum_dtype) + x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn") + dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared") + scale_m_local = T.alloc_fragment((block_M), accum_dtype) + C_shared = T.alloc_shared((block_M, block_Dstate), dtype) + prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype) + D_local = T.alloc_fragment((1), accum_dtype) + x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn") + x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + batch_idx = by % batch + chunk_idx = by // batch + # m: chunk_size + # n : headdim + m_idx = bx // T.ceildiv(headdim, block_N) + n_idx = bx % T.ceildiv(headdim, block_N) + + T.annotate_layout({ + acc_o_shared: 
tilelang.layout.make_swizzled_layout(acc_o_shared), + cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), + x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared) + }) + + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M], + dA_cs_m_shared) + T.copy(dA_cs_m_shared, dA_cs_m_local) + T.clear(acc_o) + + for i in T.Parallel(block_M): + scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) + T.copy( + C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + + (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared) + T.copy( + prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N, + 0:block_Dstate], prev_state_shared) + T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) + for i, j in T.Parallel(block_M, block_N): + acc_o[i, j] *= scale_m_local[i] + + loop_range = T.ceildiv((m_idx + 1) * block_M, block_K) + + for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy( - x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + - (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N], - x_residual_shared) - T.copy(x_residual_shared, x_residual_local) - for i, j in T.Parallel(block_M, block_N): - acc_o[i, j] += x_residual_local[i, j] * D_local[0] - - T.copy(acc_o, acc_o_shared) + cb[batch_idx, chunk_idx, bz // (nheads // ngroups), + m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K], + cb_shared) + T.copy(cb_shared, cb_local) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], + dA_cs_k_shared) + T.copy(dA_cs_k_shared, dA_cs_k_local) + for i, j in T.Parallel(block_M, block_K): + cb_local[i, + j] = cb_local[i, + j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) + T.copy(dt_shared, dt_local) + for i, j in T.Parallel(block_M, block_K): + cb_local[i, j] *= dt_local[j] + for i, j in 
T.Parallel(block_M, block_K): + cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, + cb_local[i, j], 0) T.copy( - acc_o_shared, - Output[batch_idx, chunk_idx * chunk_size + - m_idx * block_M:chunk_idx * chunk_size + (m_idx + 1) * block_M, bz, - n_idx * block_N:(n_idx + 1) * block_N]) + x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + + (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared) + T.gemm(cb_local, x_shared, acc_o) - return main - - if tune: - - @autotune(configs=get_configs(), warmup=10, rep=10) - @tilelang.jit(out_idx=[7]) - def kernel(block_M=None, - block_N=None, - block_K=None, - block_Dstate=None, - num_stages=None, - threads=None): - return kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads) - - return kernel() - else: + D_local[0] = D[bz] + T.copy( + x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + + (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N], + x_residual_shared) + T.copy(x_residual_shared, x_residual_local) + for i, j in T.Parallel(block_M, block_N): + acc_o[i, j] += x_residual_local[i, j] * D_local[0] - def kernel(block_M, block_N, block_K, block_Dstate, num_stages, threads): - return kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads) + T.copy(acc_o, acc_o_shared) + T.copy( + acc_o_shared, + Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + + (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N]) - return kernel + return main if __name__ == "__main__": @@ -234,8 +215,19 @@ def kernel(block_M, block_N, block_K, block_Dstate, num_stages, threads): if (not args.tune): kernel = chunk_scan_fwd( - batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)( - block_M=64, block_N=64, block_K=64, block_Dstate=128, num_stages=2, threads=128) + batch, + seq_len, + chunk_size, + groups, + heads, + dim, + dstate, + block_M=64, + 
block_N=64, + block_K=64, + block_Dstate=128, + num_stages=2, + threads=128) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("All checks pass.") @@ -246,11 +238,10 @@ def kernel(block_M, block_N, block_K, block_Dstate, num_stages, threads): print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) else: - best_result = chunk_scan_fwd( - batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune) - best_latency = best_result.latency - best_config = best_result.config - ref_latency = best_result.ref_latency + kernel = chunk_scan_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency print(f"Best latency: {best_latency}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best config: {best_config}") diff --git a/examples/linear_attention/example_mamba_chunk_state.py b/examples/linear_attention/example_mamba_chunk_state.py index dd299c3a7..b8b509101 100644 --- a/examples/linear_attention/example_mamba_chunk_state.py +++ b/examples/linear_attention/example_mamba_chunk_state.py @@ -49,110 +49,94 @@ def ref_program(B, x, dt, dA_cumsum): def get_configs(): - block_M = [64, 128] - block_N = [32, 64, 128] - block_K = [32, 64] - num_stages = [1, 2, 3, 4, 5] - _configs = list(itertools.product(block_M, block_N, block_K, num_stages)) - - configs = [{ - 'block_M': c[0], - 'block_N': c[1], - 'block_K': c[2], - 'num_stages': c[3], - 'threads': c[0] * 2 - } for c in _configs] - return configs + iter_params = dict( + block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] +@autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[4]) -def chunk_state_fwd(batch, seqlen, 
chunk_size, ngroups, nheads, headdim, dstate, tune=False): +def chunk_state_fwd(batch, + seqlen, + chunk_size, + ngroups, + nheads, + headdim, + dstate, + block_M=64, + block_N=64, + block_K=64, + num_stages=2, + threads=128): dtype = "float16" accum_dtype = "float" nchunks = T.ceildiv(seqlen, chunk_size) p = 1.44269504 - def kernel_func(block_M, block_N, block_K, num_stages, threads): - - @T.prim_func - def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor( - (batch, seqlen, nheads, headdim), dtype), dt: T.Tensor( - (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor( - (batch, nheads, nchunks, chunk_size), dtype), Output: T.Tensor( - (batch, nchunks, nheads, headdim, dstate), dtype)): - with T.Kernel( - nheads, - T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), - batch * nchunks, - threads=threads) as (bz, bx, by): - x_shared = T.alloc_shared((block_K, block_M), dtype) - x_local = T.alloc_fragment((block_K, block_M), dtype) - xt_local = T.alloc_fragment((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_K, block_N), dtype) - dt_shared = T.alloc_shared((block_K), dtype) - dA_cumsum_shared = T.alloc_shared((block_K), dtype) - acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) - acc_o_shared = T.alloc_shared((block_M, block_N), dtype) - scale = T.alloc_fragment((block_K), accum_dtype) - dA_cs_last = T.alloc_fragment((1), accum_dtype) - dA_cumsum_local = T.alloc_fragment((block_K), accum_dtype) - dt_local = T.alloc_fragment((block_K), accum_dtype) - - loop_range = T.ceildiv(chunk_size, block_K) - - batch_idx = by % batch - chunk_idx = by // batch - m_idx = bx // T.ceildiv(dstate, block_N) - n_idx = bx % T.ceildiv(dstate, block_N) - - T.annotate_layout({ - x_shared: tilelang.layout.make_swizzled_layout(x_shared), - acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared) - }) - - dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1] - T.clear(acc_o) - for k in T.Pipelined(loop_range, 
num_stages=num_stages): - T.copy( - x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + - (k + 1) * block_K, bz, m_idx * block_M:(m_idx + 1) * block_M], x_shared) - T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], - dA_cumsum_shared) - T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) - T.copy(dA_cumsum_shared, dA_cumsum_local) - T.copy(dt_shared, dt_local) - for i in T.Parallel(block_K): - scale[i] = T.exp2(dA_cs_last[0] * p - dA_cumsum_local[i] * p) * dt_local[i] - T.copy(x_shared, x_local) - for i, j in T.Parallel(block_M, block_K): - xt_local[i, j] = x_local[j, i] * scale[j] - T.copy( - B[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + - (k + 1) * block_K, bz // (nheads // ngroups), - n_idx * block_N:(n_idx + 1) * block_N], B_shared) - T.gemm(xt_local, B_shared, acc_o) - T.copy(acc_o, acc_o_shared) + @T.prim_func + def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor( + (batch, seqlen, nheads, headdim), dtype), dt: T.Tensor( + (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor( + (batch, nheads, nchunks, chunk_size), dtype), Output: T.Tensor( + (batch, nchunks, nheads, headdim, dstate), dtype)): + with T.Kernel( + nheads, + T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N), + batch * nchunks, + threads=threads) as (bz, bx, by): + x_shared = T.alloc_shared((block_K, block_M), dtype) + x_local = T.alloc_fragment((block_K, block_M), dtype) + xt_local = T.alloc_fragment((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + dt_shared = T.alloc_shared((block_K), dtype) + dA_cumsum_shared = T.alloc_shared((block_K), dtype) + acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) + acc_o_shared = T.alloc_shared((block_M, block_N), dtype) + scale = T.alloc_fragment((block_K), accum_dtype) + dA_cs_last = T.alloc_fragment((1), accum_dtype) + dA_cumsum_local = T.alloc_fragment((block_K), 
accum_dtype) + dt_local = T.alloc_fragment((block_K), accum_dtype) + + loop_range = T.ceildiv(chunk_size, block_K) + + batch_idx = by % batch + chunk_idx = by // batch + m_idx = bx // T.ceildiv(dstate, block_N) + n_idx = bx % T.ceildiv(dstate, block_N) + + T.annotate_layout({ + x_shared: tilelang.layout.make_swizzled_layout(x_shared), + acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared) + }) + + dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1] + T.clear(acc_o) + for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy( - acc_o_shared, - Output[batch_idx, chunk_idx, bz, m_idx * block_M:(m_idx + 1) * block_M, - n_idx * block_N:(n_idx + 1) * block_N]) - - return main - - if tune: - - @autotune(configs=get_configs(), warmup=10, rep=10) - @tilelang.jit(out_idx=[4]) - def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None): - return kernel_func(block_M, block_N, block_K, num_stages, threads) - - return kernel() - else: - - def kernel(block_M, block_N, block_K, num_stages, threads): - return kernel_func(block_M, block_N, block_K, num_stages, threads) + x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + + (k + 1) * block_K, bz, m_idx * block_M:(m_idx + 1) * block_M], x_shared) + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], + dA_cumsum_shared) + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) + T.copy(dA_cumsum_shared, dA_cumsum_local) + T.copy(dt_shared, dt_local) + for i in T.Parallel(block_K): + scale[i] = T.exp2(dA_cs_last[0] * p - dA_cumsum_local[i] * p) * dt_local[i] + T.copy(x_shared, x_local) + for i, j in T.Parallel(block_M, block_K): + xt_local[i, j] = x_local[j, i] * scale[j] + T.copy( + B[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + + (k + 1) * block_K, bz // (nheads // ngroups), + n_idx * block_N:(n_idx + 1) * block_N], B_shared) + T.gemm(xt_local, B_shared, acc_o) + T.copy(acc_o, 
acc_o_shared) + T.copy( + acc_o_shared, + Output[batch_idx, chunk_idx, bz, m_idx * block_M:(m_idx + 1) * block_M, + n_idx * block_N:(n_idx + 1) * block_N]) - return kernel + return main if __name__ == "__main__": @@ -171,8 +155,18 @@ def kernel(block_M, block_N, block_K, num_stages, threads): if (not args.tune): kernel = chunk_state_fwd( - batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)( - block_M=64, block_N=128, block_K=64, num_stages=4, threads=128) + batch, + seq_len, + chunk_size, + groups, + heads, + dim, + dstate, + block_M=64, + block_N=128, + block_K=64, + num_stages=4, + threads=128) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) print("All checks pass.") @@ -183,8 +177,7 @@ def kernel(block_M, block_N, block_K, num_stages, threads): print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) else: - best_result = chunk_state_fwd( - batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune) + best_result = chunk_state_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate) best_latency = best_result.latency best_config = best_result.config ref_latency = best_result.ref_latency diff --git a/testing/python/kernel/test_tilelang_kernel_mha_bwd.py b/testing/python/kernel/test_tilelang_kernel_mha_bwd.py index b5c228311..8d0b1156f 100644 --- a/testing/python/kernel/test_tilelang_kernel_mha_bwd.py +++ b/testing/python/kernel/test_tilelang_kernel_mha_bwd.py @@ -11,9 +11,7 @@ tilelang.testing.set_random_seed(42) -@tilelang.jit( - out_idx=[3, 4], -) +@tilelang.jit(out_idx=[3, 4],) def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N): scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] @@ -82,9 +80,8 @@ def flash_fwd( return flash_fwd -@tilelang.jit( - out_idx=[2], -) + +@tilelang.jit(out_idx=[2],) def flashattn_bwd_preprocess(batch, heads, 
seq_len, dim): dtype = "float16" accum_dtype = "float" @@ -119,9 +116,8 @@ def make_dq_layout(dQ): return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) -@tilelang.jit( - out_idx=[1], -) + +@tilelang.jit(out_idx=[1],) def flashattn_bwd_postprocess(batch, heads, seq_len, dim): dtype = "float16" accum_dtype = "float" @@ -143,9 +139,7 @@ def flash_bwd_post( return flash_bwd_post -@tilelang.jit( - out_idx=[7, 8] -) +@tilelang.jit(out_idx=[7, 8]) def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N): sm_scale = (1.0 / dim)**0.5 scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) diff --git a/tilelang/autotuner/__init__.py b/tilelang/autotuner/__init__.py index faa46309d..e570a25c5 100644 --- a/tilelang/autotuner/__init__.py +++ b/tilelang/autotuner/__init__.py @@ -394,6 +394,29 @@ def shape_equal(a, b): raise ValueError(f"Unused keys in config: {unused_keys}") config_args.append(new_kwargs) + if len(config_args) == 0: + raise ValueError("No configurations to tune, please check your `@autotune` decorator") + + # Check whether any tunable argument has already been provided. + # Use the first config as the reference set of tunable argument names. + top_config, *rest = config_args + key_args_tuple, key_kwargs_tuple = self._kernel_parameters + tunable_arguments = [key for key, _ in top_config.items()] + + # Check if any tunable argument has already been provided by comparing config keys with key_kwargs_tuple + if any(key in top_config for key, _ in key_kwargs_tuple): + logger.warning( + f"Tunable parameters {tunable_arguments} already provided during auto-tuning.
Skipping compilation and using direct JIT" + ) + # compile the kernel with the provided parameters + jit_kernel = self.jit_compile() + autotuner_result = AutotuneResult( + libcode=jit_kernel.get_kernel_source(), + func=jit_kernel.prim_func, + kernel=jit_kernel) + self._memory_cache[key] = autotuner_result + return autotuner_result + num_workers = max(1, int(get_available_cpu_count() * 0.9)) pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) futures = [] @@ -657,6 +680,16 @@ def autotune( # This is the new public interface This decorator can be used without arguments (e.g., `@tilelang.jit`): Applies JIT compilation with default settings. + + Tips: + - If you want to skip the auto-tuning process, you can override the tunable parameters by passing them explicitly in the function call. + ```python + if enable_autotune: + kernel = flashattn(batch, heads, seq_len, dim, is_causal) + else: + kernel = flashattn( + batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256) + ``` Parameters ---------- diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py index 5e0c67598..f2890229a 100644 --- a/tilelang/autotuner/param.py +++ b/tilelang/autotuner/param.py @@ -144,12 +144,12 @@ class AutotuneResult: func: Optimized function. kernel: Compiled kernel function.
""" - latency: float - config: dict - ref_latency: float - libcode: str - func: Callable - kernel: Callable + latency: Optional[float] = None + config: Optional[dict] = None + ref_latency: Optional[float] = None + libcode: Optional[str] = None + func: Optional[Callable] = None + kernel: Optional[Callable] = None def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel): """ From a7d6ee45bf15b753ab488ac890a3c08d28375d36 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 8 Jul 2025 07:56:12 +0000 Subject: [PATCH 3/7] [Refactor] Update configuration handling in benchmark scripts - Refactored the `get_configs` functions in benchmark scripts to accept a variable argument list, improving flexibility in configuration management. - Enhanced the `matmul` and `flashattn` functions to utilize the updated configuration approach, streamlining parameter handling for kernel tuning. - Added `@autotune` decorators to relevant functions, ensuring consistent autotuning behavior across benchmarks. - Cleaned up redundant code and improved overall readability in the affected files. --- benchmark/matmul/benchmark_matmul.py | 192 ++++++++---------- .../matmul/benchmark_matmul_intrinsic.py | 38 ++-- benchmark/matmul_fp8/benchmark_matmul.py | 190 ++++++++--------- .../example_mha_fwd_bhsd_wgmma_pipelined.py | 22 +- tilelang/autotuner/__init__.py | 19 +- 5 files changed, 209 insertions(+), 252 deletions(-) diff --git a/benchmark/matmul/benchmark_matmul.py b/benchmark/matmul/benchmark_matmul.py index aa98cae12..cd1142ac1 100644 --- a/benchmark/matmul/benchmark_matmul.py +++ b/benchmark/matmul/benchmark_matmul.py @@ -32,7 +32,7 @@ def ref_program(A, B): return A @ B.T -def get_configs(M, N, K, with_roller=False): +def get_configs(args, kwargs): """ Generate a list of configuration dictionaries that will be used for tuning. 
@@ -47,6 +47,8 @@ def get_configs(M, N, K, with_roller=False): Each configuration dict includes various block sizes, pipeline stages, thread numbers, and other parameters to explore during autotuning. """ + M, N, K, with_roller = args[:4] + if with_roller: from tilelang.carver.template import MatmulTemplate from tilelang.carver.arch import CUDA @@ -104,7 +106,26 @@ def get_configs(M, N, K, with_roller=False): return configs -def matmul(M, N, K, with_roller): + +@autotune( + configs=get_configs, + warmup=3, + rep=20, +) +@jit(out_idx=[2],) +def matmul( + M, + N, + K, + with_roller, + block_M=None, + block_N=None, + block_K=None, + num_stages=None, + thread_num=None, + policy=None, + enable_rasteration=None, +): """ Create an autotuned matrix multiplication kernel for matrices of shape: - A: (M, K) @@ -131,117 +152,68 @@ def matmul(M, N, K, with_roller): The baseline latency of the reference program (for computing speedup). """ - # Decorate the kernel with autotune & jit, specifying: - # - Tuning config list - # - Profiling keys - # - Warmup and repetition counts for better measurement - # - A reference program for correctness verification - # - The "tvm" profiler backend - # - HIP as the compilation target (modify as needed for your hardware) - - @autotune( - configs=get_configs(M, N, K, with_roller), - warmup=3, - rep=20, - ) - @jit(out_idx=[2],) - def kernel( - block_M=None, - block_N=None, - block_K=None, - num_stages=None, - thread_num=None, - policy=None, - enable_rasteration=None, + + # Use half-precision for input data to reduce memory bandwidth, + # accumulate in float for better numerical accuracy + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): """ - The actual kernel to compute C = A @ B^T. - - Parameters - ---------- - block_M : int - Block size in M dimension. - block_N : int - Block size in N dimension. 
- block_K : int - Block size in K dimension. - num_stages : int - Number of pipelined stages (for asynchronous load). - thread_num : int - Number of threads to use per block. - enable_rasteration : bool - Whether to enable rasterization (swizzling) optimization. - k_pack : int - K dimension packing factor to improve memory coalescing. - - Returns - ------- - Function - A TVM Tensor Language function (T.prim_func) that computes matmul. + The compiled TVM function for block-level matrix multiplication. + + - We divide the entire (M, N) domain into blocks of shape + (block_M, block_N). + - Each block has its own allocated shared memory for sub-blocks + of A and B. + - The partial results go into C_local, and then we copy them back + to global memory C. """ - # Use half-precision for input data to reduce memory bandwidth, - # accumulate in float for better numerical accuracy - dtype = "float16" - accum_dtype = "float" - - @T.prim_func - def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), - ): - """ - The compiled TVM function for block-level matrix multiplication. - - - We divide the entire (M, N) domain into blocks of shape - (block_M, block_N). - - Each block has its own allocated shared memory for sub-blocks - of A and B. - - The partial results go into C_local, and then we copy them back - to global memory C. - """ - # Bind x-dimension to block index in N, - # y-dimension to block index in M. 
- with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - - # Allocate shared memory for A sub-block of shape (block_M, block_K) - A_shared = T.alloc_shared((block_M, block_K), dtype) - # Allocate shared memory for B sub-block of shape (block_N, block_K) - B_shared = T.alloc_shared((block_N, block_K), dtype) - # Allocate a local fragment for intermediate accumulation - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - # Allocate a shared memory for C sub-block of shape (block_M, block_N) - C_shared = T.alloc_shared((block_M, block_N), dtype) - - # Enable (or disable) swizzling optimization - T.use_swizzle(panel_size=10, enable=enable_rasteration) - - # Clear out the accumulation buffer - T.clear(C_local) - - # Loop over sub-blocks in K dimension, pipelined by num_stages - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - # Load a sub-block of A from global memory into A_shared - T.copy(A[by * block_M, k * block_K], A_shared) - # Load a sub-block of B from global memory into B_shared - T.copy(B[bx * block_N, k * block_K], B_shared) - # Perform a partial matrix multiplication: - # C_local += A_shared @ B_shared^T - T.gemm( - A_shared, - B_shared, - C_local, - transpose_B=True, - policy=policy, - ) - # Write back the results from C_local to the global memory C - T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - - return main - - return kernel() + # Bind x-dimension to block index in N, + # y-dimension to block index in M. 
+ with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + + # Allocate shared memory for A sub-block of shape (block_M, block_K) + A_shared = T.alloc_shared((block_M, block_K), dtype) + # Allocate shared memory for B sub-block of shape (block_N, block_K) + B_shared = T.alloc_shared((block_N, block_K), dtype) + # Allocate a local fragment for intermediate accumulation + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + # Allocate a shared memory for C sub-block of shape (block_M, block_N) + C_shared = T.alloc_shared((block_M, block_N), dtype) + + # Enable (or disable) swizzling optimization + T.use_swizzle(panel_size=10, enable=enable_rasteration) + + # Clear out the accumulation buffer + T.clear(C_local) + + # Loop over sub-blocks in K dimension, pipelined by num_stages + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + # Load a sub-block of A from global memory into A_shared + T.copy(A[by * block_M, k * block_K], A_shared) + # Load a sub-block of B from global memory into B_shared + T.copy(B[bx * block_N, k * block_K], B_shared) + # Perform a partial matrix multiplication: + # C_local += A_shared @ B_shared^T + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + policy=policy, + ) + # Write back the results from C_local to the global memory C + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main if __name__ == "__main__": diff --git a/benchmark/matmul/benchmark_matmul_intrinsic.py b/benchmark/matmul/benchmark_matmul_intrinsic.py index 57c164239..d6f4b9ec3 100644 --- a/benchmark/matmul/benchmark_matmul_intrinsic.py +++ b/benchmark/matmul/benchmark_matmul_intrinsic.py @@ -165,7 +165,7 @@ def ref_program(A, B): return A @ B.T -def get_configs(M, N, K, with_roller=False): +def get_configs(args, kwargs): """ Generate a list of configuration dictionaries that will be used for tuning. 
@@ -180,6 +180,9 @@ def get_configs(M, N, K, with_roller=False): Each configuration dict includes various block sizes, pipeline stages, thread numbers, and other parameters to explore during autotuning. """ + M, N, K = args[:3] + with_roller = args[6] + if with_roller: from tilelang.carver.template import MatmulTemplate from tilelang.carver.arch import CUDA @@ -236,7 +239,14 @@ def get_configs(M, N, K, with_roller=False): return configs - +@autotune( + configs=get_configs, + warmup=3, + rep=5, + ref_prog=ref_program, + skip_check=True, +) +@tl.jit(out_idx=[2],) def matmul( M, N, @@ -245,25 +255,19 @@ def matmul( out_dtype="float16", accum_dtype="float16", with_roller=False, + block_row_warps=None, + block_col_warps=None, + warp_row_tiles=None, + warp_col_tiles=None, + chunk=None, + stage=None, + enable_rasteration=None, ): """Create an autotuned tensor core matrix multiplication kernel.""" - @autotune( - configs=get_configs(M, N, K, with_roller), - warmup=3, - rep=5, - ref_prog=ref_program, - skip_check=True, - ) - @tl.jit(out_idx=[2],) + def kernel( - block_row_warps=None, - block_col_warps=None, - warp_row_tiles=None, - warp_col_tiles=None, - chunk=None, - stage=None, - enable_rasteration=None, + ): return tl_matmul( M, diff --git a/benchmark/matmul_fp8/benchmark_matmul.py b/benchmark/matmul_fp8/benchmark_matmul.py index 8235ef1cd..7fe517f7a 100644 --- a/benchmark/matmul_fp8/benchmark_matmul.py +++ b/benchmark/matmul_fp8/benchmark_matmul.py @@ -33,7 +33,7 @@ def ref_program(A, B): return A.float() @ B.T.float() -def get_configs(M, N, K, with_roller=False): +def get_configs(args, kwargs): """ Generate a list of configuration dictionaries that will be used for tuning. @@ -48,6 +48,8 @@ def get_configs(M, N, K, with_roller=False): Each configuration dict includes various block sizes, pipeline stages, thread numbers, and other parameters to explore during autotuning. 
""" + M, N, K, with_roller = args[:4] + if with_roller: from tilelang.carver.template import MatmulTemplate from tilelang.carver.arch import CUDA @@ -106,7 +108,25 @@ def get_configs(M, N, K, with_roller=False): return configs -def matmul(M, N, K, with_roller): +@autotune( + configs=get_configs, + warmup=3, + rep=20, +) +@jit(out_idx=[2],) +def matmul( + M, + N, + K, + with_roller, + block_M=None, + block_N=None, + block_K=None, + num_stages=None, + thread_num=None, + policy=None, + enable_rasteration=None, +): """ Create an autotuned matrix multiplication kernel for matrices of shape: - A: (M, K) @@ -133,117 +153,67 @@ def matmul(M, N, K, with_roller): The baseline latency of the reference program (for computing speedup). """ - # Decorate the kernel with autotune & jit, specifying: - # - Tuning config list - # - Profiling keys - # - Warmup and repetition counts for better measurement - # - A reference program for correctness verification - # - The "tvm" profiler backend - # - HIP as the compilation target (modify as needed for your hardware) - - @autotune( - configs=get_configs(M, N, K, with_roller), - warmup=3, - rep=20, - ) - @jit(out_idx=[2],) - def kernel( - block_M=None, - block_N=None, - block_K=None, - num_stages=None, - thread_num=None, - policy=None, - enable_rasteration=None, + # Use half-precision for input data to reduce memory bandwidth, + # accumulate in float for better numerical accuracy + dtype = "e4m3_float8" + accum_dtype = "float" + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), ): """ - The actual kernel to compute C = A @ B^T. - - Parameters - ---------- - block_M : int - Block size in M dimension. - block_N : int - Block size in N dimension. - block_K : int - Block size in K dimension. - num_stages : int - Number of pipelined stages (for asynchronous load). - thread_num : int - Number of threads to use per block. 
- enable_rasteration : bool - Whether to enable rasterization (swizzling) optimization. - k_pack : int - K dimension packing factor to improve memory coalescing. - - Returns - ------- - Function - A TVM Tensor Language function (T.prim_func) that computes matmul. + The compiled TVM function for block-level matrix multiplication. + + - We divide the entire (M, N) domain into blocks of shape + (block_M, block_N). + - Each block has its own allocated shared memory for sub-blocks + of A and B. + - The partial results go into C_local, and then we copy them back + to global memory C. """ - # Use half-precision for input data to reduce memory bandwidth, - # accumulate in float for better numerical accuracy - dtype = "e4m3_float8" - accum_dtype = "float" - - @T.prim_func - def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), - ): - """ - The compiled TVM function for block-level matrix multiplication. - - - We divide the entire (M, N) domain into blocks of shape - (block_M, block_N). - - Each block has its own allocated shared memory for sub-blocks - of A and B. - - The partial results go into C_local, and then we copy them back - to global memory C. - """ - # Bind x-dimension to block index in N, - # y-dimension to block index in M. 
- with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - - # Allocate shared memory for A sub-block of shape (block_M, block_K) - A_shared = T.alloc_shared((block_M, block_K), dtype) - # Allocate shared memory for B sub-block of shape (block_N, block_K) - B_shared = T.alloc_shared((block_N, block_K), dtype) - # Allocate a local fragment for intermediate accumulation - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - # Allocate a shared memory for C sub-block of shape (block_M, block_N) - C_shared = T.alloc_shared((block_M, block_N), dtype) - - # Enable (or disable) swizzling optimization - T.use_swizzle(panel_size=10, enable=enable_rasteration) - - # Clear out the accumulation buffer - T.clear(C_local) - - # Loop over sub-blocks in K dimension, pipelined by num_stages - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - # Load a sub-block of A from global memory into A_shared - T.copy(A[by * block_M, k * block_K], A_shared) - # Load a sub-block of B from global memory into B_shared - T.copy(B[bx * block_N, k * block_K], B_shared) - # Perform a partial matrix multiplication: - # C_local += A_shared @ B_shared^T - T.gemm( - A_shared, - B_shared, - C_local, - transpose_B=True, - policy=policy, - ) - # Write back the results from C_local to the global memory C - T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - - return main - - return kernel() + # Bind x-dimension to block index in N, + # y-dimension to block index in M. 
+ with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + + # Allocate shared memory for A sub-block of shape (block_M, block_K) + A_shared = T.alloc_shared((block_M, block_K), dtype) + # Allocate shared memory for B sub-block of shape (block_N, block_K) + B_shared = T.alloc_shared((block_N, block_K), dtype) + # Allocate a local fragment for intermediate accumulation + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + # Allocate a shared memory for C sub-block of shape (block_M, block_N) + C_shared = T.alloc_shared((block_M, block_N), dtype) + + # Enable (or disable) swizzling optimization + T.use_swizzle(panel_size=10, enable=enable_rasteration) + + # Clear out the accumulation buffer + T.clear(C_local) + + # Loop over sub-blocks in K dimension, pipelined by num_stages + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + # Load a sub-block of A from global memory into A_shared + T.copy(A[by * block_M, k * block_K], A_shared) + # Load a sub-block of B from global memory into B_shared + T.copy(B[bx * block_N, k * block_K], B_shared) + # Perform a partial matrix multiplication: + # C_local += A_shared @ B_shared^T + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + policy=policy, + ) + # Write back the results from C_local to the global memory C + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main if __name__ == "__main__": diff --git a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py index a6a93d914..f33e37d55 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -18,16 +18,18 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[3]) -def flashattn(batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=128, - 
block_N=128, - num_stages=2, - threads=256): +def flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + is_causal, + block_M=128, + block_N=128, + num_stages=2, + threads=256 +): scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] diff --git a/tilelang/autotuner/__init__.py b/tilelang/autotuner/__init__.py index e570a25c5..6f9e263be 100644 --- a/tilelang/autotuner/__init__.py +++ b/tilelang/autotuner/__init__.py @@ -265,6 +265,9 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30): sig = inspect.signature(self.fn) parameters = sig.parameters + if isinstance(self.configs, Callable): + self.configs = self.configs(*self._kernel_parameters) + key = self.generate_cache_key(parameters) with self._lock: @@ -527,7 +530,7 @@ class _AutoTunerImplementation: warmup: int = 25 rep: int = 100 timeout: int = 100 - configs: Any = None + configs: Union[Dict, Callable] = None supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto ref_prog: Callable = None supply_prog: Callable = None @@ -539,7 +542,7 @@ class _AutoTunerImplementation: cache_input_tensors: bool = False def __init__(self, - configs: Any, + configs: Union[Dict, Callable], warmup: int = 25, rep: int = 100, timeout: int = 100, @@ -606,7 +609,6 @@ def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Any]: warmup = self.warmup rep = self.rep timeout = self.timeout - configs = self.configs @functools.wraps(fn) def wrapper(*args, **kwargs): @@ -623,7 +625,7 @@ def jit_compile(**config_arg): compile_arguments = fn(__return_compile_arguments=True) autotuner = AutoTuner( - fn, configs=configs).set_profile_args( + fn, configs=self.configs).set_profile_args( supply_type=self.supply_type, ref_prog=self.ref_prog, supply_prog=self.supply_prog, @@ -659,7 +661,7 @@ def jit_compile(**config_arg): def autotune( # This is the new public interface func: Union[Callable[_P, _RProg], PrimFunc, None] = None, *, # Indicates 
subsequent arguments are keyword-only - configs: Any, + configs: Union[Dict, Callable], # profile arguments warmup: int = 25, rep: int = 100, @@ -697,6 +699,13 @@ def autotune( # This is the new public interface If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter. If using `@tilelang.jit` directly on a function, this argument is implicitly the function to be decorated (and `out_idx` will be `None`). + configs : Dict or Callable + Configuration space to explore during auto-tuning. + warmup : int, optional + Number of warmup iterations before timing. + rep : int, optional + Number of repetitions for timing measurements. + timeout : int, optional target : Union[str, Target], optional Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto". target_host : Union[str, Target], optional From bbac4a0786249d72dcf40d841358bbae91be9494 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 8 Jul 2025 07:56:36 +0000 Subject: [PATCH 4/7] [Refactor] Clean up formatting and update subproject commit - Updated the subproject commit reference in the TVM directory to indicate a dirty state. - Removed unnecessary blank lines and improved formatting in the `benchmark_matmul` and `benchmark_matmul_fp8` scripts for better readability. - Streamlined the function definitions in the `flashattn` example script to enhance clarity and maintainability. 
--- benchmark/matmul/benchmark_matmul.py | 11 ++++------ .../matmul/benchmark_matmul_intrinsic.py | 6 ++--- benchmark/matmul_fp8/benchmark_matmul.py | 9 ++++---- .../example_mha_fwd_bhsd_wgmma_pipelined.py | 22 +++++++++---------- 4 files changed, 20 insertions(+), 28 deletions(-) diff --git a/benchmark/matmul/benchmark_matmul.py b/benchmark/matmul/benchmark_matmul.py index cd1142ac1..f2d7d8c73 100644 --- a/benchmark/matmul/benchmark_matmul.py +++ b/benchmark/matmul/benchmark_matmul.py @@ -106,7 +106,6 @@ def get_configs(args, kwargs): return configs - @autotune( configs=get_configs, warmup=3, @@ -114,9 +113,9 @@ def get_configs(args, kwargs): ) @jit(out_idx=[2],) def matmul( - M, - N, - K, + M, + N, + K, with_roller, block_M=None, block_N=None, @@ -152,7 +151,6 @@ def matmul( The baseline latency of the reference program (for computing speedup). """ - # Use half-precision for input data to reduce memory bandwidth, # accumulate in float for better numerical accuracy dtype = "float16" @@ -176,8 +174,7 @@ def main( """ # Bind x-dimension to block index in N, # y-dimension to block index in M. 
- with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K), dtype) diff --git a/benchmark/matmul/benchmark_matmul_intrinsic.py b/benchmark/matmul/benchmark_matmul_intrinsic.py index d6f4b9ec3..b141b6e6d 100644 --- a/benchmark/matmul/benchmark_matmul_intrinsic.py +++ b/benchmark/matmul/benchmark_matmul_intrinsic.py @@ -239,6 +239,7 @@ def get_configs(args, kwargs): return configs + @autotune( configs=get_configs, warmup=3, @@ -265,10 +266,7 @@ def matmul( ): """Create an autotuned tensor core matrix multiplication kernel.""" - - def kernel( - - ): + def kernel(): return tl_matmul( M, N, diff --git a/benchmark/matmul_fp8/benchmark_matmul.py b/benchmark/matmul_fp8/benchmark_matmul.py index 7fe517f7a..55eecf957 100644 --- a/benchmark/matmul_fp8/benchmark_matmul.py +++ b/benchmark/matmul_fp8/benchmark_matmul.py @@ -115,9 +115,9 @@ def get_configs(args, kwargs): ) @jit(out_idx=[2],) def matmul( - M, - N, - K, + M, + N, + K, with_roller, block_M=None, block_N=None, @@ -176,8 +176,7 @@ def main( """ # Bind x-dimension to block index in N, # y-dimension to block index in M. 
- with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): # Allocate shared memory for A sub-block of shape (block_M, block_K) A_shared = T.alloc_shared((block_M, block_K), dtype) diff --git a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py index f33e37d55..a6a93d914 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -18,18 +18,16 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[3]) -def flashattn( - batch, - heads, - seq_q, - seq_kv, - dim, - is_causal, - block_M=128, - block_N=128, - num_stages=2, - threads=256 -): +def flashattn(batch, + heads, + seq_q, + seq_kv, + dim, + is_causal, + block_M=128, + block_N=128, + num_stages=2, + threads=256): scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] From 582092874614b13ee01c19ad83e680d8955b16e9 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 8 Jul 2025 09:45:51 +0000 Subject: [PATCH 5/7] [Refactor] Update AutoTuner configuration handling - Modified the AutoTuner class to check if kernel parameters are set before processing tunable arguments, improving robustness in configuration handling. - Enhanced the logic for skipping compilation when tunable parameters are already provided, ensuring efficient use of resources. - Updated comments for clarity and maintainability. 
--- .../example_convolution_autotune.py | 5 +-- tilelang/autotuner/__init__.py | 34 ++++++++++--------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/examples/convolution/example_convolution_autotune.py b/examples/convolution/example_convolution_autotune.py index 0adc16005..53ab8bd7b 100644 --- a/examples/convolution/example_convolution_autotune.py +++ b/examples/convolution/example_convolution_autotune.py @@ -225,6 +225,7 @@ def convolution(N, block_K, num_stages, thread_num, + enable_rasteration, dtype="float16", accum_dtype="float"): KH, KW = K, K @@ -293,14 +294,14 @@ def main(n: int = 128, with_roller: bool = True): N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p ref_prog = ref_program(S, P, D) - use_autotune = True + if use_autotune: result = get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller) print(result.config) kernel = result.kernel else: config = get_heuristic_config() - kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_dix=[2]) + kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_idx=[2]) profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) tilelang_latency = profiler.do_bench() diff --git a/tilelang/autotuner/__init__.py b/tilelang/autotuner/__init__.py index 6f9e263be..f5b19479b 100644 --- a/tilelang/autotuner/__init__.py +++ b/tilelang/autotuner/__init__.py @@ -400,25 +400,27 @@ def shape_equal(a, b): if len(config_args) == 0: raise ValueError("No configurations to tune, please check your `@autotune` decorator") - # check if the tunable arguments has been tuned. + # check if the tunable arguments has been set. 
# get the back config argument top_config, *rest = config_args - key_args_tuple, key_kwargs_tuple = self._kernel_parameters - tunable_arguments = [key for key, _ in top_config.items()] - # Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple - if any(key in top_config for key, _ in key_kwargs_tuple): - logger.warning( - f"Tunable parameters {tunable_arguments} already provided during auto-tuning. Skipping compilation and using direct JIT" - ) - # compile the kernel with the provided parameters - jit_kernel = self.jit_compile() - autotuner_result = AutotuneResult( - libcode=jit_kernel.get_kernel_source(), - func=jit_kernel.prim_func, - kernel=jit_kernel) - self._memory_cache[key] = autotuner_result - return autotuner_result + if self._kernel_parameters is not None: + key_args_tuple, key_kwargs_tuple = self._kernel_parameters + tunable_arguments = [key for key, _ in top_config.items()] + + # Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple + if any(key in top_config for key, _ in key_kwargs_tuple): + logger.warning( + f"Tunable parameters {tunable_arguments} already provided during auto-tuning. 
Skipping compilation and using direct JIT" + ) + # compile the kernel with the provided parameters + jit_kernel = self.jit_compile() + autotuner_result = AutotuneResult( + libcode=jit_kernel.get_kernel_source(), + func=jit_kernel.prim_func, + kernel=jit_kernel) + self._memory_cache[key] = autotuner_result + return autotuner_result num_workers = max(1, int(get_available_cpu_count() * 0.9)) pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) From 4d91008f5a1c83af2d8016521633ecba462bf83a Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 8 Jul 2025 12:59:20 +0000 Subject: [PATCH 6/7] lint fix --- .../flash_attention/example_mha_fwd_bshd.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/examples/flash_attention/example_mha_fwd_bshd.py b/examples/flash_attention/example_mha_fwd_bshd.py index aaf711559..5e0a9dcae 100644 --- a/examples/flash_attention/example_mha_fwd_bshd.py +++ b/examples/flash_attention/example_mha_fwd_bshd.py @@ -18,7 +18,15 @@ def get_configs(): @autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[3]) -def flashattn(batch, heads, seq_len, dim, is_causal, tune=False): +def flashattn(batch, + heads, + seq_len, + dim, + is_causal, + block_M=64, + block_N=64, + num_stages=1, + threads=128): scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" @@ -168,8 +176,15 @@ def main( if (not tune): kernel = flashattn( - batch, heads, seq_len, dim, is_causal, tune=tune)( - block_M=128, block_N=128, num_stages=1, threads=128) + batch, + heads, + seq_len, + dim, + is_causal, + block_M=128, + block_N=128, + num_stages=1, + threads=128) ref_program_processed = partial(ref_program, is_causal=is_causal) profiler = kernel.get_profiler() profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) From f3ac180074ee497ab9c5d0b90e430a012f00600e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 8 Jul 2025 14:05:39 +0000 Subject: 
[PATCH 7/7] Update TVM subproject commit to indicate dirty state and modify MHA backward test cases - Updated the subproject commit reference in the TVM directory to reflect a dirty state. - Adjusted the `test_mha_bwd` function to use a new configuration for the MHA backward tests, changing the context size from 128 to 256. - Uncommented the main testing function call for potential execution. --- testing/python/kernel/test_tilelang_kernel_mha_bwd.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/testing/python/kernel/test_tilelang_kernel_mha_bwd.py b/testing/python/kernel/test_tilelang_kernel_mha_bwd.py index 80de51e09..fa522e09c 100644 --- a/testing/python/kernel/test_tilelang_kernel_mha_bwd.py +++ b/testing/python/kernel/test_tilelang_kernel_mha_bwd.py @@ -303,10 +303,9 @@ def assert_mha_equal(batch, h, n_ctx, d_head, causal): def test_mha_bwd(): - assert_mha_equal(8, 32, 128, 64, False) - assert_mha_equal(8, 32, 128, 64, True) + assert_mha_equal(8, 32, 256, 64, False) + assert_mha_equal(8, 32, 256, 64, True) if __name__ == "__main__": - # tilelang.testing.main() - assert_mha_equal(8, 32, 256, 64, False) + tilelang.testing.main()