From 0bfc49a5affa31873ca23e93d59acf90504a8eca Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 5 Oct 2025 18:50:00 +0800 Subject: [PATCH 1/4] example fix --- examples/flash_attention/example_gqa_bwd.py | 169 ++++++++++++++--- .../example_gqa_bwd_wgmma_pipelined.py | 172 ++++++++++++++++-- examples/flash_attention/example_mha_bwd.py | 152 ++++++++++++++-- .../example_mha_bwd_wgmma_pipelined.py | 163 +++++++++++++++-- src/op/gemm.cc | 5 +- tilelang/language/atomic.py | 17 +- tilelang/language/copy.py | 76 +------- tilelang/language/customize.py | 89 +-------- tilelang/language/utils.py | 86 ++++++++- tilelang/utils/language.py | 9 +- 10 files changed, 682 insertions(+), 256 deletions(-) diff --git a/examples/flash_attention/example_gqa_bwd.py b/examples/flash_attention/example_gqa_bwd.py index 49e60ec86..c3f26a9ce 100644 --- a/examples/flash_attention/example_gqa_bwd.py +++ b/examples/flash_attention/example_gqa_bwd.py @@ -147,7 +147,98 @@ def flash_bwd_post( @tilelang.jit(pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }) -def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk)**0.5 + scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) + + T.annotate_layout({ + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + }) + + T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, 
num_stages=num_stages): + T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], + 0) + T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim_qk): + if k * block_N + i < seq_len: + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) + T.copy(dv, dv_shared) + T.atomic_add(dV[bz, by * block_M:(by+1) * block_M, bx // groups, :], dv_shared) + T.copy(dk, dk_shared) + T.atomic_add(dK[bz, by * block_M:(by+1) * block_M, bx // groups, :], dk_shared) + + return flash_bwd + + +@tilelang.jit(pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, +}) +def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): sm_scale = (1.0 / dim_qk)**0.5 scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) head_kv = heads // groups @@ -171,7 +262,7 @@ def flash_bwd( dK: T.Tensor(dk_shape, dtype), # type: ignore dV: T.Tensor(dv_shape, dtype), # type: ignore ): - with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype) q = T.alloc_shared([block_N, dim_qk], dtype) @@ -202,10 +293,13 @@ def flash_bwd( T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) - for k in T.Pipelined(loop_st, loop_ed, num_stages=1): + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) @@ -213,9 +307,6 @@ def flash_bwd( for i, j in T.Parallel(block_M, block_N): qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) - T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) - T.clear(dsT) - T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.copy(qkT, qkT_cast) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) @@ -244,7 +335,7 @@ def flash_bwd( class _attention(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, causal, groups=1): + def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): BATCH, N_CTX, H, D_HEAD_QK 
= q.shape D_HEAD_V = v.shape[-1] block_M = 128 @@ -253,6 +344,7 @@ def forward(ctx, q, k, v, causal, groups=1): o, lse = mod(q, k, v) ctx.save_for_backward(q, k, v, o, lse) ctx.causal = causal + ctx.use_atomic = use_atomic return o @staticmethod @@ -268,23 +360,39 @@ def maybe_contiguous(x): return x do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] - block_M = 64 + block_M = 128 block_N = 32 mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK) delta = mod_prep(o, do) - kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, - groups) - shape_q = [BATCH, N_CTX, H, D_HEAD_QK] - shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel - shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel - dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) - dk = torch.empty(shape_k, dtype=torch.float16, device=q.device) - dv = torch.empty(shape_v, dtype=torch.float16, device=q.device) - kernel(q, k, v, do, lse, delta, dq, dk, dv) - dq = mod_post(dq) - dk, dv = dk.sum(0), dv.sum(0) - return dq, dk, dv, None, None + + if ctx.use_atomic: + kernel = flashattn_bwd_atomic_add(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, + threads=256, num_stages=2, groups=groups) + shape_q = [BATCH, N_CTX, H, D_HEAD_QK] + shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] + shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] + dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) + dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device) + dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = mod_post(dq) + dk = dk.to(torch.float16) + dv = dv.to(torch.float16) + else: + kernel = flashattn_bwd_split(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, + threads=256, num_stages=2, groups=groups) + shape_q = [BATCH, N_CTX, H, D_HEAD_QK] + shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel + shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel + dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) + dk = torch.empty(shape_k, dtype=torch.float16, device=q.device) + dv = torch.empty(shape_v, dtype=torch.float16, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = mod_post(dq) + dk, dv = dk.sum(0), dv.sum(0) + + return dq, dk, dv, None, None, None attention = _attention.apply @@ -321,7 +429,8 @@ def main(BATCH: int = 1, D_HEAD_QK: int = 192, D_HEAD_V: int = 128, groups: int = 16, - causal: bool = False): + causal: bool = False, + use_atomic: bool = True): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v @@ -341,7 +450,7 @@ def main(BATCH: int = 1, dO = ( torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()) - O = attention(Q, K, V, causal, groups) + O = attention(Q, K, V, causal, groups, use_atomic) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None dK, K.grad = K.grad.clone(), None @@ -382,7 +491,19 @@ def run1(): parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', type=bool, default=False, 
help='Causal flag') + parser.add_argument('--causal', action='store_true', help='Causal flag') parser.add_argument('--groups', type=int, default=16, help='groups') + parser.add_argument('--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') + parser.add_argument('--use_split', action='store_true', default=False, help='Use split for dK/dV') args = parser.parse_args() - main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal) + + # Handle backward compatibility and logic + if args.use_split: + use_atomic = False + elif args.use_atomic: + use_atomic = True + else: + # Default: use atomic + use_atomic = True + + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py index 4083dfadd..912009229 100644 --- a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py @@ -147,7 +147,109 @@ def flash_bwd_post( @tilelang.jit(pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }) -def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1): +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): + sm_scale = (1.0 / dim_qk)**0.5 + scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim_qk] + k_shape = [batch, seq_len, head_kv, dim_qk] + v_shape = [batch, seq_len, head_kv, dim_v] + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), # type: ignore + K: T.Tensor(k_shape, dtype), # type: ignore + V: T.Tensor(v_shape, dtype), # type: ignore + dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(q_shape, accum_dtype), # type: ignore + dK: T.Tensor(k_shape, accum_dtype), # type: ignore + dV: T.Tensor(v_shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim_qk], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim_qk], dtype) + V_shared = T.alloc_shared([block_M, dim_v], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim_v], dtype) + dv = T.alloc_fragment([block_M, dim_v], accum_dtype) + dk = T.alloc_fragment([block_M, dim_qk], accum_dtype) + dq = T.alloc_fragment([block_N, dim_qk], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) + + T.annotate_layout({ + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + }) + + T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + 
T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm( + K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], + 0) + T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm( + V_shared, + do, + dsT, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + wg_wait=-1) + T.wait_wgmma(1) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + + T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.wait_wgmma(0) + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) + T.wait_wgmma(0) + for i, j in T.Parallel(block_N, dim_qk): + if k * block_N + i < seq_len: + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) + T.copy(dv, dv_shared) + T.atomic_add(dV[bz, by * block_M:(by+1) * block_M, bx // groups, :], dv_shared) + T.copy(dk, dk_shared) + for i, j in T.Parallel(block_M, dim_qk): + T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j]) + + return flash_bwd + + +@tilelang.jit(pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, +}) +def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): sm_scale = (1.0 / dim_qk)**0.5 scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) head_kv = heads // groups @@ -171,7 +273,7 @@ def flash_bwd( dK: T.Tensor(dk_shape, dtype), # type: ignore dV: T.Tensor(dv_shape, dtype), # type: ignore ): - with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim_qk], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype) q = T.alloc_shared([block_N, dim_qk], dtype) @@ -202,7 +304,7 @@ def flash_bwd( T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) - for k in T.Pipelined(loop_st, loop_ed, num_stages=2): + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm( @@ -255,7 +357,7 @@ def flash_bwd( class _attention(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, causal, groups=1): + def forward(ctx, q, k, v, causal, groups=1, use_atomic=True): BATCH, N_CTX, H, D_HEAD_QK = q.shape D_HEAD_V = v.shape[-1] block_M = 128 @@ -264,6 +366,7 @@ def forward(ctx, q, k, v, causal, groups=1): o, lse = mod(q, k, v) ctx.save_for_backward(q, k, v, o, lse) ctx.causal = causal + ctx.use_atomic = use_atomic return o @staticmethod @@ -284,18 +387,34 @@ def maybe_contiguous(x): mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK) delta = mod_prep(o, do) - 
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, - groups) - shape_q = [BATCH, N_CTX, H, D_HEAD_QK] - shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel - shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel - dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) - dk = torch.empty(shape_k, dtype=torch.float16, device=q.device) - dv = torch.empty(shape_v, dtype=torch.float16, device=q.device) - kernel(q, k, v, do, lse, delta, dq, dk, dv) - dq = mod_post(dq) - dk, dv = dk.sum(0), dv.sum(0) - return dq, dk, dv, None, None + + if ctx.use_atomic: + kernel = flashattn_bwd_atomic_add(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, + threads=256, num_stages=2, groups=groups) + shape_q = [BATCH, N_CTX, H, D_HEAD_QK] + shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] + shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] + dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) + dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device) + dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = mod_post(dq) + dk = dk.to(torch.float16) + dv = dv.to(torch.float16) + else: + kernel = flashattn_bwd_split(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, + threads=256, num_stages=2, groups=groups) + shape_q = [BATCH, N_CTX, H, D_HEAD_QK] + shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel + shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel + dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device) + dk = torch.empty(shape_k, dtype=torch.float16, device=q.device) + dv = torch.empty(shape_v, dtype=torch.float16, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = mod_post(dq) + dk, dv = dk.sum(0), dv.sum(0) + + return dq, dk, dv, None, None, None attention = _attention.apply @@ -332,7 +451,8 @@ def main(BATCH: int = 1, D_HEAD_QK: int = 192, D_HEAD_V: int = 128, groups: int = 16, - causal: bool = False): + causal: bool = False, + use_atomic: bool = True): flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V total_flops = 3 * flops_per_qk + 2 * flops_per_v @@ -352,7 +472,7 @@ def main(BATCH: int = 1, dO = ( torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()) - O = attention(Q, K, V, causal, groups) + O = attention(Q, K, V, causal, groups, use_atomic) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None dK, K.grad = K.grad.clone(), None @@ -393,7 +513,19 @@ def run1(): parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') parser.add_argument('--d_head_qk', type=int, default=192, help='Head dimension for Q/K') parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') - parser.add_argument('--causal', type=bool, default=False, help='Causal flag') + parser.add_argument('--causal', action='store_true', help='Causal flag') parser.add_argument('--groups', type=int, default=16, help='groups') + parser.add_argument('--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') + parser.add_argument('--use_split', action='store_true', default=False, help='Use split for dK/dV') args = parser.parse_args() - main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal) + + # Handle backward compatibility and logic + if args.use_split: + 
use_atomic = False + elif args.use_atomic: + use_atomic = True + else: + # Default: use atomic + use_atomic = True + + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) diff --git a/examples/flash_attention/example_mha_bwd.py b/examples/flash_attention/example_mha_bwd.py index 244c6594a..a44660780 100644 --- a/examples/flash_attention/example_mha_bwd.py +++ b/examples/flash_attention/example_mha_bwd.py @@ -149,7 +149,94 @@ def flash_bwd_post( @tilelang.jit(pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }) -def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim, is_causal, block_M, block_N, threads=128, num_stages=2): + sm_scale = (1.0 / dim)**0.5 + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def flash_bwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, accum_dtype), # type: ignore + dV: T.Tensor(shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], accum_dtype) + + T.annotate_layout({ + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + }) + T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared) + T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], + 0) + T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + 
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim): + if k * block_N + i < seq_len: + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) + T.copy(dv, dv_shared) + T.atomic_add(dV[bz, by * block_M:(by+1) * block_M, bx, :], dv_shared) + T.copy(dk, dk_shared) + T.atomic_add(dK[bz, by * block_M:(by+1) * block_M, bx, :], dk_shared) + + return flash_bwd + + +@tilelang.jit(pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, +}) +def flashattn_bwd_split(batch, heads, seq_len, dim, is_causal, block_M, block_N, threads=128, num_stages=2): sm_scale = (1.0 / dim)**0.5 scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] @@ -168,13 +255,9 @@ def flash_bwd( dK: T.Tensor(shape, dtype), # type: ignore dV: T.Tensor(shape, dtype), # type: ignore ): - with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype) - # should not store K to local if dim is large - # K_local = T.alloc_fragment([block_M, dim], dtype) - # K_local_T = T.alloc_fragment([block_M, dim], dtype) - # V_local = T.alloc_fragment([block_M, dim], dtype) q = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_M, dim], dtype) qkT = T.alloc_fragment([block_M, block_N], accum_dtype) @@ -202,7 +285,7 @@ def flash_bwd( T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) - for k in T.Pipelined(loop_st, loop_ed, num_stages=2): + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) @@ -242,13 +325,14 @@ def flash_bwd( class _attention(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, causal): + def forward(ctx, q, k, v, causal, use_atomic=True): BATCH, N_CTX, H, D_HEAD = q.shape block_M = 64 block_N = 64 if D_HEAD <= 128 else 32 o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v) ctx.save_for_backward(q, k, v, o, lse) ctx.causal = causal + ctx.use_atomic = use_atomic return o @staticmethod @@ -267,14 +351,29 @@ def maybe_contiguous(x): kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD) delta = kernel_prep(o, do) - kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N) - shape = [BATCH, N_CTX, H, D_HEAD] - dq = torch.zeros(shape, dtype=torch.float32, device=q.device) - dk = torch.empty(shape, dtype=torch.float16, device=q.device) - dv = torch.empty(shape, dtype=torch.float16, device=q.device) - kernel(q, k, v, do, lse, delta, dq, dk, dv) - dq = kernel_post(dq) - return dq, dk, dv, None + + if ctx.use_atomic: + kernel = flashattn_bwd_atomic_add(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, + threads=128, num_stages=2) + shape = [BATCH, N_CTX, H, D_HEAD] + dq = torch.zeros(shape, dtype=torch.float32, device=q.device) + dk = torch.zeros(shape, dtype=torch.float32, device=q.device) + dv = 
torch.zeros(shape, dtype=torch.float32, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = kernel_post(dq) + dk = dk.to(torch.float16) + dv = dv.to(torch.float16) + else: + kernel = flashattn_bwd_split(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, + threads=128, num_stages=2) + shape = [BATCH, N_CTX, H, D_HEAD] + dq = torch.zeros(shape, dtype=torch.float32, device=q.device) + dk = torch.empty(shape, dtype=torch.float16, device=q.device) + dv = torch.empty(shape, dtype=torch.float16, device=q.device) + kernel(q, k, v, do, lse, delta, dq, dk, dv) + dq = kernel_post(dq) + + return dq, dk, dv, None, None attention = _attention.apply @@ -300,7 +399,9 @@ def main( N_CTX: int = 1024, D_HEAD: int = 64, causal: bool = False, + use_atomic: bool = True, ): + print(f"Test with use_atomic: {use_atomic}") flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD total_flops = 5 * flops_per_matmul if causal: @@ -311,7 +412,7 @@ def main( K = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_() dO = torch.randn_like(Q) - O = attention(Q, K, V, causal) + O = attention(Q, K, V, causal, use_atomic) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None dK, K.grad = K.grad.clone(), None @@ -327,6 +428,7 @@ def main( assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + print('All checks passed.✅') def run(): O_ref.backward(dO, retain_graph=True) @@ -350,6 +452,18 @@ def run1(): parser.add_argument('--h', type=int, default=32, help='Number of heads') parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') parser.add_argument('--d_head', type=int, default=64, help='Head dimension') - parser.add_argument('--causal', type=bool, default=False, help='Causal flag') + parser.add_argument('--causal', action='store_true', help='Causal flag') + parser.add_argument('--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') + parser.add_argument('--use_split', action='store_true', default=False, help='Use split for dK/dV') args = parser.parse_args() - main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) + + # Handle backward compatibility and logic + if args.use_split: + use_atomic = False + elif args.use_atomic: + use_atomic = True + else: + # Default: use atomic + use_atomic = True + + main(args.batch, args.h, args.n_ctx, args.d_head, args.causal, use_atomic) diff --git a/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py b/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py index 3af22541d..ed17cab63 100644 --- a/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py @@ -146,7 +146,105 @@ def flash_bwd_post( @tilelang.jit(pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }) -def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): +def flashattn_bwd_atomic_add(batch, heads, seq_len, dim, is_causal, block_M, block_N, threads=256, num_stages=2): + sm_scale = (1.0 / dim)**0.5 + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def flash_bwd( + Q: T.Tensor(shape, dtype), # type: ignore + K: T.Tensor(shape, dtype), # type: ignore + V: T.Tensor(shape, dtype), # type: ignore + dO: T.Tensor(shape, dtype), # type: ignore + lse: T.Tensor([batch, heads, 
seq_len], accum_dtype), # type: ignore + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore + dQ: T.Tensor(shape, accum_dtype), # type: ignore + dK: T.Tensor(shape, accum_dtype), # type: ignore + dV: T.Tensor(shape, accum_dtype), # type: ignore + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], accum_dtype) + + T.annotate_layout({ + dQ: make_dq_layout(dQ), + K_shared: tilelang.layout.make_swizzled_layout(K_shared), + }) + + T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared) + T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm( + K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], + 0) + T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.gemm( + V_shared, + do, + dsT, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + wg_wait=-1) + T.wait_wgmma(1) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) + + T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.wait_wgmma(0) + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) + T.wait_wgmma(0) + for i, j in T.Parallel(block_N, dim): + if k * block_N + i < seq_len: + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) + T.copy(dv, dv_shared) + T.atomic_add(dV[bz, by * block_M:(by+1) * block_M, bx, :], dv_shared) + T.copy(dk, dk_shared) + T.atomic_add(dK[bz, by * block_M:(by+1) * block_M, bx, :], dk_shared) + + return flash_bwd + + +@tilelang.jit(pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, +}) +def flashattn_bwd_split(batch, heads, seq_len, dim, is_causal, block_M, block_N, threads=256, num_stages=2): sm_scale = (1.0 / dim)**0.5 scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] @@ -165,13 +263,9 @@ def flash_bwd( dK: T.Tensor(shape, dtype), # type: 
ignore dV: T.Tensor(shape, dtype), # type: ignore ): - with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): K_shared = T.alloc_shared([block_M, dim], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype) - # should not store K to local if dim is large - # K_local = T.alloc_fragment([block_M, dim], dtype) - # K_local_T = T.alloc_fragment([block_M, dim], dtype) - # V_local = T.alloc_fragment([block_M, dim], dtype) q = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_M, dim], dtype) qkT = T.alloc_fragment([block_M, block_N], accum_dtype) @@ -200,7 +294,7 @@ def flash_bwd( T.clear(dk) loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 loop_ed = T.ceildiv(seq_len, block_N) - for k in T.Pipelined(loop_st, loop_ed, num_stages=2): + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) T.clear(qkT) T.gemm( @@ -251,7 +345,7 @@ def flash_bwd( class _attention(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, causal): + def forward(ctx, q, k, v, causal, use_atomic=True): BATCH, N_CTX, H, D_HEAD = q.shape block_M = 64 block_N = 64 if D_HEAD <= 128 else 32 @@ -259,6 +353,7 @@ def forward(ctx, q, k, v, causal): o, lse = mod(q, k, v) ctx.save_for_backward(q, k, v, o, lse) ctx.causal = causal + ctx.use_atomic = use_atomic return o @staticmethod @@ -277,14 +372,29 @@ def maybe_contiguous(x): mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD) delta = mod_prep(o, do) - mod = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N) - shape = [BATCH, N_CTX, H, D_HEAD] - dq = torch.zeros(shape, dtype=torch.float32, device=q.device) - dk = torch.empty(shape, dtype=torch.float16, device=q.device) - dv = torch.empty(shape, dtype=torch.float16, device=q.device) - mod(q, k, v, do, lse, delta, dq, dk, dv) - dq = mod_post(dq) - return dq, dk, dv, None + + if ctx.use_atomic: + mod = flashattn_bwd_atomic_add(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, + threads=256, num_stages=2) + shape = [BATCH, N_CTX, H, D_HEAD] + dq = torch.zeros(shape, dtype=torch.float32, device=q.device) + dk = torch.zeros(shape, dtype=torch.float32, device=q.device) + dv = torch.zeros(shape, dtype=torch.float32, device=q.device) + mod(q, k, v, do, lse, delta, dq, dk, dv) + dq = mod_post(dq) + dk = dk.to(torch.float16) + dv = dv.to(torch.float16) + else: + mod = flashattn_bwd_split(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, + threads=256, num_stages=2) + shape = [BATCH, N_CTX, H, D_HEAD] + dq = torch.zeros(shape, dtype=torch.float32, device=q.device) + dk = torch.empty(shape, dtype=torch.float16, device=q.device) + dv = torch.empty(shape, dtype=torch.float16, device=q.device) + mod(q, k, v, do, lse, delta, dq, dk, dv) + dq = mod_post(dq) + + return dq, dk, dv, None, None attention = _attention.apply @@ -310,7 +420,9 @@ def main( N_CTX: int = 1024, D_HEAD: int = 64, causal: bool = False, + use_atomic: bool = True, ): + print(f"Test with use_atomic: {use_atomic}") flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD total_flops = 5 * flops_per_matmul if causal: @@ -321,7 +433,7 @@ def main( K = torch.empty_like(Q).normal_().requires_grad_() V = torch.empty_like(Q).normal_().requires_grad_() dO = torch.randn_like(Q) - O = attention(Q, K, V, causal) + O = attention(Q, K, V, causal, 
use_atomic) O.backward(dO, retain_graph=True) dQ, Q.grad = Q.grad.clone(), None dK, K.grad = K.grad.clone(), None @@ -337,6 +449,7 @@ def main( assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) + print('All checks passed.✅') def run(): O_ref.backward(dO, retain_graph=True) @@ -360,6 +473,18 @@ def run1(): parser.add_argument('--h', type=int, default=32, help='Number of heads') parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') parser.add_argument('--d_head', type=int, default=64, help='Head dimension') - parser.add_argument('--causal', type=bool, default=False, help='Causal flag') + parser.add_argument('--causal', action='store_true', help='Causal flag') + parser.add_argument('--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') + parser.add_argument('--use_split', action='store_true', default=False, help='Use split for dK/dV') args = parser.parse_args() - main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) + + # Handle backward compatibility and logic + if args.use_split: + use_atomic = False + elif args.use_atomic: + use_atomic = True + else: + # Default: use atomic + use_atomic = True + + main(args.batch, args.h, args.n_ctx, args.d_head, args.causal, use_atomic) diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 0c496376c..972f31ad4 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -286,7 +286,7 @@ std::pair GemmWarpPolicyNode::ComputeWarpPartition( } ICHECK(m_warp * n_warp == num_warps) - << "m_warp * n_warp must equal num_warps"; + << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp << ", n_warp: " << n_warp << ", num_warps: " << num_warps; // Store the computed values in the object's member variables this->m_warp = m_warp; @@ -370,6 +370,9 @@ std::pair GemmWarpPolicyNode::ComputeWarpPartition( } else { ICHECK(0) << "Unknown GemmWarpPolicy"; } + ICHECK(m_warp * n_warp == num_warps) + << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp << ", n_warp: " << n_warp << ", num_warps: " << num_warps; + // Store the computed values in the object's member variables this->m_warp = m_warp; this->n_warp = n_warp; diff --git a/tilelang/language/atomic.py b/tilelang/language/atomic.py index 333cb7ad6..718272395 100644 --- a/tilelang/language/atomic.py +++ b/tilelang/language/atomic.py @@ -3,9 +3,11 @@ """Atomic operations for tilelang.""" import tilelang.language as T -from tvm import ir +from tvm import ir, tir from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op from typing import Optional +from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region +from tilelang.utils.language import get_buffer_region_from_load _MEMORY_ORDER_ID_MAP = { "relaxed": 0, @@ -200,14 +202,17 @@ def get_extent(data): extent = max(src_extent, dst_extent) def _to_region(data, access_type): - from .customize import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region - - if isinstance(data, Var) and T.has_let_value(data): + if isinstance(data, tir.Var) and T.has_let_value(data): data = T.get_let_value(data) - if isinstance(data, Buffer): + if isinstance(data, tir.Buffer): return buffer_to_tile_region(data, access_type) - elif isinstance(data, BufferRegion): + elif isinstance(data, tir.BufferRegion): return buffer_region_to_tile_region(data, access_type, extent) + elif isinstance(data, tir.BufferLoad): + region = 
get_buffer_region_from_load(data) + if region is None: + return buffer_load_to_tile_region(data, access_type, extent) + return buffer_region_to_tile_region(region, access_type, extent) else: return buffer_load_to_tile_region(data, access_type, extent) diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index c08ca3836..dadeaf7af 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -4,81 +4,7 @@ from tilelang import language as T from tilelang.utils.language import get_buffer_region_from_load from tvm import ir, tir - - -def region(buffer: tir.BufferLoad, access_type: str, *args: tir.PrimExpr): - """Create a memory region descriptor for tile operations. - - Args: - buffer (tir.BufferLoad): The buffer to create a region for - access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write - *args (tir.PrimExpr): Extent expressions defining the region size - - Returns: - tir.Call: A region descriptor for tile operations - """ - access_type = {"r": 1, "w": 2, "rw": 3}[access_type] - return tir.call_intrin("handle", tir.op.Op.get("tl.region"), buffer, access_type, *args) - - -def buffer_to_tile_region(buffer: tir.Buffer, access_type: str): - """Convert a TVM buffer to a tile region descriptor. - - Args: - buffer (tir.Buffer): The buffer to convert - access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write - - Returns: - tir.Call: A region descriptor covering the entire buffer - """ - mins = [0 for _ in buffer.shape] - extents = [x for x in buffer.shape] - return region(T.BufferLoad(buffer, mins), access_type, *extents) - - -def buffer_load_to_tile_region(load: tir.BufferLoad, access_type: str, extents: List[tir.PrimExpr]): - """Convert a buffer load operation to a tile region descriptor. - - Args: - load (tir.BufferLoad): The buffer load operation - access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write - extents (List[tir.PrimExpr]): List of expressions defining the region size - - Returns: - tir.Call: A region descriptor for the loaded area - """ - indices = load.indices - if len(indices) > len(extents): - # (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, " - # f"region will be expanded in the last 2 dimensions") - new_extents = [] - for _ in range(len(indices) - len(extents)): - new_extents.append(1) - for extent in extents: - new_extents.append(extent) - extents = new_extents - assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}" - return region(load, access_type, *extents) - - -def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str, - extents: List[tir.PrimExpr]): - """Convert a buffer region to a tile region descriptor. 
- - Args: - buffer_region (tir.BufferRegion): The buffer region to convert - access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write - - Returns: - tir.Call: A region descriptor for the specified buffer region - """ - mins = [x.min for x in buffer_region.region] - region_extents = [x.extent for x in buffer_region.region] - assert len(region_extents) >= len( - extents - ), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}" - - return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents) +from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region def copy(src: Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion], diff --git a/tilelang/language/customize.py b/tilelang/language/customize.py index 8492e9ff5..ac9b36a50 100644 --- a/tilelang/language/customize.py +++ b/tilelang/language/customize.py @@ -4,94 +4,7 @@ from tvm.tir import PrimExpr, Buffer, BufferLoad, BufferRegion, op from typing import List, Union from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401 - - -def region(buffer: BufferLoad, access_type: str, *args: PrimExpr): - """ - Create a tile memory-region descriptor for a BufferLoad. - - Maps access_type ('r', 'w', 'rw') to the numeric codes expected by the `tl.region` intrinsic - (1, 2, 3 respectively) and returns a tir.Call representing the region with the provided extents. - - Parameters: - buffer (tir.BufferLoad): The BufferLoad that identifies the underlying buffer and indices. - access_type (str): One of 'r', 'w', or 'rw' indicating read, write, or read-write access. - *args (tir.PrimExpr): Extent expressions for each region dimension. - - Returns: - tir.Call: A call to the `tl.region` intrinsic describing the memory region. - - Raises: - KeyError: If access_type is not one of 'r', 'w', or 'rw'. - """ - access_type = {"r": 1, "w": 2, "rw": 3}[access_type] - return T.call_intrin("handle", op.Op.get("tl.region"), buffer, access_type, *args) - - -def buffer_to_tile_region(buffer: Buffer, access_type: str): - """Convert a TVM buffer to a tile region descriptor. - - Args: - buffer (tir.Buffer): The buffer to convert - access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write - - Returns: - tir.Call: A region descriptor covering the entire buffer - """ - mins = [0 for _ in buffer.shape] - extents = [x for x in buffer.shape] - return region(T.BufferLoad(buffer, mins), access_type, *extents) - - -def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List[PrimExpr]): - """Convert a buffer load operation to a tile region descriptor. 
- - Args: - load (tir.BufferLoad): The buffer load operation - access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write - extents (List[tir.PrimExpr]): List of expressions defining the region size - - Returns: - tir.Call: A region descriptor for the loaded area - """ - indices = load.indices - if len(indices) > len(extents): - # (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, " - # f"region will be expanded in the last 2 dimensions") - new_extents = [] - for _ in range(len(indices) - len(extents)): - new_extents.append(1) - for extent in extents: - new_extents.append(extent) - extents = new_extents - assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}" - return region(load, access_type, *extents) - - -def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str, - extents: List[PrimExpr]): - """ - Create a tl region descriptor for the given BufferRegion. - - Parameters: - buffer_region (tir.BufferRegion): Source buffer region whose `region` items provide mins and extents. - access_type (str): Access mode: "r", "w", or "rw". - extents (List[PrimExpr]): Requested extents; must have length <= the number of extents in buffer_region.region. - - Returns: - tir.Call: A tile-region descriptor (tl.region) covering the buffer_region. - - Raises: - AssertionError: If the number of extents in buffer_region.region is smaller than len(extents). - """ - mins = [x.min for x in buffer_region.region] - region_extents = [x.extent for x in buffer_region.region] - assert len(region_extents) >= len( - extents - ), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}" - - return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents) - +from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr: """Perform a 4-element dot product with accumulation (DP4A). diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py index d896726e6..e97af60d8 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -1,8 +1,92 @@ from tilelang import tvm as tvm from typing import List -from tvm.tir import PrimExpr +from tvm import tir +from tvm.tir import PrimExpr, Buffer, BufferLoad, BufferRegion, op +from tilelang import language as T +def region(buffer: BufferLoad, access_type: str, *args: PrimExpr): + """ + Create a tile memory-region descriptor for a BufferLoad. + + Maps access_type ('r', 'w', 'rw') to the numeric codes expected by the `tl.region` intrinsic + (1, 2, 3 respectively) and returns a tir.Call representing the region with the provided extents. + + Parameters: + buffer (tir.BufferLoad): The BufferLoad that identifies the underlying buffer and indices. + access_type (str): One of 'r', 'w', or 'rw' indicating read, write, or read-write access. + *args (tir.PrimExpr): Extent expressions for each region dimension. + + Returns: + tir.Call: A call to the `tl.region` intrinsic describing the memory region. + + Raises: + KeyError: If access_type is not one of 'r', 'w', or 'rw'. + """ + access_type = {"r": 1, "w": 2, "rw": 3}[access_type] + return T.call_intrin("handle", op.Op.get("tl.region"), buffer, access_type, *args) + +def buffer_to_tile_region(buffer: Buffer, access_type: str): + """Convert a TVM buffer to a tile region descriptor. 
+
+    Args:
+        buffer (tir.Buffer): The buffer to convert
+        access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
+
+    Returns:
+        tir.Call: A region descriptor covering the entire buffer
+    """
+    mins = [0 for _ in buffer.shape]
+    extents = [x for x in buffer.shape]
+    return region(T.BufferLoad(buffer, mins), access_type, *extents)
+
+def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List[PrimExpr]):
+    """Convert a buffer load operation to a tile region descriptor.
+
+    Args:
+        load (tir.BufferLoad): The buffer load operation
+        access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
+        extents (List[tir.PrimExpr]): List of expressions defining the region size
+
+    Returns:
+        tir.Call: A region descriptor for the loaded area
+    """
+    indices = load.indices
+    if len(indices) > len(extents):
+        # (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, "
+        #  f"region will be expanded in the last 2 dimensions")
+        new_extents = []
+        for _ in range(len(indices) - len(extents)):
+            new_extents.append(1)
+        for extent in extents:
+            new_extents.append(extent)
+        extents = new_extents
+    assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}"
+    return region(load, access_type, *extents)
+
+
+def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str,
+                                 extents: List[tir.PrimExpr]):
+    """Convert a buffer region to a tile region descriptor.
+
+    Args:
+        buffer_region (tir.BufferRegion): The buffer region to convert
+        access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
+
+    Returns:
+        tir.Call: A region descriptor for the specified buffer region
+    """
+    mins = [x.min for x in buffer_region.region]
+    region_extents = [x.extent for x in buffer_region.region]
+    assert len(region_extents) >= len(
+        extents
+    ), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"
+
+    return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
+
 def index_to_coordinates(index, shape) -> List[PrimExpr]:
     """
     Convert a flat (linear) index into multi-dimensional coordinates for a given shape.
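
Note: the region helpers above are now the single shared implementation imported by copy.py, customize.py and atomic.py. As a quick illustration of what they produce, here is a minimal sketch that relies only on the signatures defined in this patch; the buffer name, shape, dtype and extents below are made up for the example and are not part of the change.

from tvm import tir
# Import path introduced by this patch (tilelang/language/utils.py).
from tilelang.language.utils import buffer_to_tile_region, buffer_load_to_tile_region

# A stand-in accumulator buffer; the shape is arbitrary for this sketch.
dQ = tir.decl_buffer((1, 1024, 32, 64), "float32", name="dQ")

# Whole-buffer region: mins are all zeros, extents equal the buffer shape,
# and the "w" access mode is encoded as code 2 in the tl.region call.
whole = buffer_to_tile_region(dQ, "w")

# Region anchored at a BufferLoad with fewer extents than indices: the helper
# left-pads the extents with 1s to match the index rank, i.e. [1, 1, 32, 64].
load = tir.BufferLoad(dQ, [0, 0, 0, 0])
tile = buffer_load_to_tile_region(load, "w", [32, 64])
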
diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py index ab24d5161..82b969460 100644 --- a/tilelang/utils/language.py +++ b/tilelang/utils/language.py @@ -132,7 +132,10 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> Optional[tir.Buf buffer, indices = buffer_load.buffer, buffer_load.indices regions = [] for indice in indices: - if not isinstance(indice, tir.Ramp): - return None - regions.append(ir.Range.from_min_extent(indice.base, indice.lanes)) + if isinstance(indice, tir.Ramp): + regions.append(ir.Range.from_min_extent(indice.base, indice.lanes)) + elif isinstance(indice, tir.PrimExpr): + regions.append(ir.Range.from_min_extent(indice, 1)) + else: + raise ValueError("Unsupported type: ", type(indice)) return tir.BufferRegion(buffer, regions) From 9a4a359f5836d3ed131281935e2d1708f80ccbf8 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 5 Oct 2025 18:56:35 +0800 Subject: [PATCH 2/4] lint fix --- examples/flash_attention/example_gqa_bwd.py | 65 +++++++++++++++---- .../example_gqa_bwd_wgmma_pipelined.py | 63 +++++++++++++++--- examples/flash_attention/example_mha_bwd.py | 38 ++++++++--- .../example_mha_bwd_wgmma_pipelined.py | 38 ++++++++--- src/op/gemm.cc | 6 +- tilelang/language/copy.py | 2 +- tilelang/language/customize.py | 4 +- tilelang/language/utils.py | 5 +- 8 files changed, 174 insertions(+), 47 deletions(-) diff --git a/examples/flash_attention/example_gqa_bwd.py b/examples/flash_attention/example_gqa_bwd.py index c3f26a9ce..d529925c7 100644 --- a/examples/flash_attention/example_gqa_bwd.py +++ b/examples/flash_attention/example_gqa_bwd.py @@ -147,7 +147,17 @@ def flash_bwd_post( @tilelang.jit(pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }) -def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): +def flashattn_bwd_atomic_add(batch, + heads, + seq_len, + dim_qk, + dim_v, + is_causal, + block_M, + block_N, + threads=256, + num_stages=2, + groups=1): sm_scale = (1.0 / dim_qk)**0.5 scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) head_kv = heads // groups @@ -228,9 +238,9 @@ def flash_bwd( if k * block_N + i < seq_len: T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, by * block_M:(by+1) * block_M, bx // groups, :], dv_shared) + T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared) T.copy(dk, dk_shared) - T.atomic_add(dK[bz, by * block_M:(by+1) * block_M, bx // groups, :], dk_shared) + T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared) return flash_bwd @@ -238,7 +248,17 @@ def flash_bwd( @tilelang.jit(pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }) -def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): +def flashattn_bwd_split(batch, + heads, + seq_len, + dim_qk, + dim_v, + is_causal, + block_M, + block_N, + threads=256, + num_stages=2, + groups=1): sm_scale = (1.0 / dim_qk)**0.5 scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) head_kv = heads // groups @@ -367,8 +387,18 @@ def maybe_contiguous(x): delta = mod_prep(o, do) if ctx.use_atomic: - kernel = flashattn_bwd_atomic_add(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, - threads=256, num_stages=2, groups=groups) + kernel = flashattn_bwd_atomic_add( + BATCH, + H, + N_CTX, + D_HEAD_QK, + D_HEAD_V, + ctx.causal, + block_M, + block_N, + threads=256, + num_stages=2, + 
groups=groups) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] @@ -380,8 +410,18 @@ def maybe_contiguous(x): dk = dk.to(torch.float16) dv = dv.to(torch.float16) else: - kernel = flashattn_bwd_split(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, - threads=256, num_stages=2, groups=groups) + kernel = flashattn_bwd_split( + BATCH, + H, + N_CTX, + D_HEAD_QK, + D_HEAD_V, + ctx.causal, + block_M, + block_N, + threads=256, + num_stages=2, + groups=groups) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel @@ -493,8 +533,10 @@ def run1(): parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') parser.add_argument('--causal', action='store_true', help='Causal flag') parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument('--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') - parser.add_argument('--use_split', action='store_true', default=False, help='Use split for dK/dV') + parser.add_argument( + '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') + parser.add_argument( + '--use_split', action='store_true', default=False, help='Use split for dK/dV') args = parser.parse_args() # Handle backward compatibility and logic @@ -506,4 +548,5 @@ def run1(): # Default: use atomic use_atomic = True - main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, + use_atomic) diff --git a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py index 912009229..00bf5034f 100644 --- a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py @@ -147,7 +147,17 @@ def flash_bwd_post( @tilelang.jit(pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }) -def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): +def flashattn_bwd_atomic_add(batch, + heads, + seq_len, + dim_qk, + dim_v, + is_causal, + block_M, + block_N, + threads=256, + num_stages=2, + groups=1): sm_scale = (1.0 / dim_qk)**0.5 scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) head_kv = heads // groups @@ -238,7 +248,7 @@ def flash_bwd( if k * block_N + i < seq_len: T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, by * block_M:(by+1) * block_M, bx // groups, :], dv_shared) + T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dv_shared) T.copy(dk, dk_shared) for i, j in T.Parallel(block_M, dim_qk): T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk_shared[i, j]) @@ -249,7 +259,17 @@ def flash_bwd( @tilelang.jit(pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }) -def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1): +def flashattn_bwd_split(batch, + heads, + seq_len, + dim_qk, + dim_v, + is_causal, + block_M, + block_N, + threads=256, + num_stages=2, + groups=1): sm_scale = (1.0 / dim_qk)**0.5 scale = (1.0 / dim_qk)**0.5 * 1.44269504 # log2(e) head_kv = heads // groups @@ 
-389,8 +409,18 @@ def maybe_contiguous(x): delta = mod_prep(o, do) if ctx.use_atomic: - kernel = flashattn_bwd_atomic_add(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, - threads=256, num_stages=2, groups=groups) + kernel = flashattn_bwd_atomic_add( + BATCH, + H, + N_CTX, + D_HEAD_QK, + D_HEAD_V, + ctx.causal, + block_M, + block_N, + threads=256, + num_stages=2, + groups=groups) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK] shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V] @@ -402,8 +432,18 @@ def maybe_contiguous(x): dk = dk.to(torch.float16) dv = dv.to(torch.float16) else: - kernel = flashattn_bwd_split(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, - threads=256, num_stages=2, groups=groups) + kernel = flashattn_bwd_split( + BATCH, + H, + N_CTX, + D_HEAD_QK, + D_HEAD_V, + ctx.causal, + block_M, + block_N, + threads=256, + num_stages=2, + groups=groups) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel @@ -515,8 +555,10 @@ def run1(): parser.add_argument('--d_head_v', type=int, default=128, help='Head dimension for V') parser.add_argument('--causal', action='store_true', help='Causal flag') parser.add_argument('--groups', type=int, default=16, help='groups') - parser.add_argument('--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') - parser.add_argument('--use_split', action='store_true', default=False, help='Use split for dK/dV') + parser.add_argument( + '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') + parser.add_argument( + '--use_split', action='store_true', default=False, help='Use split for dK/dV') args = parser.parse_args() # Handle backward compatibility and logic @@ -528,4 +570,5 @@ def run1(): # Default: use atomic use_atomic = True - main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic) + main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, + use_atomic) diff --git a/examples/flash_attention/example_mha_bwd.py b/examples/flash_attention/example_mha_bwd.py index a44660780..cacb848ff 100644 --- a/examples/flash_attention/example_mha_bwd.py +++ b/examples/flash_attention/example_mha_bwd.py @@ -149,7 +149,15 @@ def flash_bwd_post( @tilelang.jit(pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }) -def flashattn_bwd_atomic_add(batch, heads, seq_len, dim, is_causal, block_M, block_N, threads=128, num_stages=2): +def flashattn_bwd_atomic_add(batch, + heads, + seq_len, + dim, + is_causal, + block_M, + block_N, + threads=128, + num_stages=2): sm_scale = (1.0 / dim)**0.5 scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] @@ -226,9 +234,9 @@ def flash_bwd( if k * block_N + i < seq_len: T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, by * block_M:(by+1) * block_M, bx, :], dv_shared) + T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx, :], dv_shared) T.copy(dk, dk_shared) - T.atomic_add(dK[bz, by * block_M:(by+1) * block_M, bx, :], dk_shared) + T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx, :], dk_shared) return flash_bwd @@ -236,7 +244,15 @@ def flash_bwd( @tilelang.jit(pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }) -def flashattn_bwd_split(batch, heads, seq_len, dim, is_causal, block_M, 
block_N, threads=128, num_stages=2): +def flashattn_bwd_split(batch, + heads, + seq_len, + dim, + is_causal, + block_M, + block_N, + threads=128, + num_stages=2): sm_scale = (1.0 / dim)**0.5 scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] @@ -353,8 +369,8 @@ def maybe_contiguous(x): delta = kernel_prep(o, do) if ctx.use_atomic: - kernel = flashattn_bwd_atomic_add(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, - threads=128, num_stages=2) + kernel = flashattn_bwd_atomic_add( + BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=128, num_stages=2) shape = [BATCH, N_CTX, H, D_HEAD] dq = torch.zeros(shape, dtype=torch.float32, device=q.device) dk = torch.zeros(shape, dtype=torch.float32, device=q.device) @@ -364,8 +380,8 @@ def maybe_contiguous(x): dk = dk.to(torch.float16) dv = dv.to(torch.float16) else: - kernel = flashattn_bwd_split(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, - threads=128, num_stages=2) + kernel = flashattn_bwd_split( + BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=128, num_stages=2) shape = [BATCH, N_CTX, H, D_HEAD] dq = torch.zeros(shape, dtype=torch.float32, device=q.device) dk = torch.empty(shape, dtype=torch.float16, device=q.device) @@ -453,8 +469,10 @@ def run1(): parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') parser.add_argument('--d_head', type=int, default=64, help='Head dimension') parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') - parser.add_argument('--use_split', action='store_true', default=False, help='Use split for dK/dV') + parser.add_argument( + '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') + parser.add_argument( + '--use_split', action='store_true', default=False, help='Use split for dK/dV') args = parser.parse_args() # Handle backward compatibility and logic diff --git a/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py b/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py index ed17cab63..44db09f9a 100644 --- a/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py @@ -146,7 +146,15 @@ def flash_bwd_post( @tilelang.jit(pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }) -def flashattn_bwd_atomic_add(batch, heads, seq_len, dim, is_causal, block_M, block_N, threads=256, num_stages=2): +def flashattn_bwd_atomic_add(batch, + heads, + seq_len, + dim, + is_causal, + block_M, + block_N, + threads=256, + num_stages=2): sm_scale = (1.0 / dim)**0.5 scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] @@ -234,9 +242,9 @@ def flash_bwd( if k * block_N + i < seq_len: T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) T.copy(dv, dv_shared) - T.atomic_add(dV[bz, by * block_M:(by+1) * block_M, bx, :], dv_shared) + T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx, :], dv_shared) T.copy(dk, dk_shared) - T.atomic_add(dK[bz, by * block_M:(by+1) * block_M, bx, :], dk_shared) + T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx, :], dk_shared) return flash_bwd @@ -244,7 +252,15 @@ def flash_bwd( @tilelang.jit(pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }) -def flashattn_bwd_split(batch, heads, seq_len, dim, is_causal, block_M, block_N, threads=256, num_stages=2): +def flashattn_bwd_split(batch, + heads, + seq_len, + dim, + 
is_causal, + block_M, + block_N, + threads=256, + num_stages=2): sm_scale = (1.0 / dim)**0.5 scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] @@ -374,8 +390,8 @@ def maybe_contiguous(x): delta = mod_prep(o, do) if ctx.use_atomic: - mod = flashattn_bwd_atomic_add(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, - threads=256, num_stages=2) + mod = flashattn_bwd_atomic_add( + BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=256, num_stages=2) shape = [BATCH, N_CTX, H, D_HEAD] dq = torch.zeros(shape, dtype=torch.float32, device=q.device) dk = torch.zeros(shape, dtype=torch.float32, device=q.device) @@ -385,8 +401,8 @@ def maybe_contiguous(x): dk = dk.to(torch.float16) dv = dv.to(torch.float16) else: - mod = flashattn_bwd_split(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, - threads=256, num_stages=2) + mod = flashattn_bwd_split( + BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=256, num_stages=2) shape = [BATCH, N_CTX, H, D_HEAD] dq = torch.zeros(shape, dtype=torch.float32, device=q.device) dk = torch.empty(shape, dtype=torch.float16, device=q.device) @@ -474,8 +490,10 @@ def run1(): parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') parser.add_argument('--d_head', type=int, default=64, help='Head dimension') parser.add_argument('--causal', action='store_true', help='Causal flag') - parser.add_argument('--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') - parser.add_argument('--use_split', action='store_true', default=False, help='Use split for dK/dV') + parser.add_argument( + '--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV') + parser.add_argument( + '--use_split', action='store_true', default=False, help='Use split for dK/dV') args = parser.parse_args() # Handle backward compatibility and logic diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 972f31ad4..a8f26ef29 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -286,7 +286,8 @@ std::pair GemmWarpPolicyNode::ComputeWarpPartition( } ICHECK(m_warp * n_warp == num_warps) - << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp << ", n_warp: " << n_warp << ", num_warps: " << num_warps; + << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp + << ", n_warp: " << n_warp << ", num_warps: " << num_warps; // Store the computed values in the object's member variables this->m_warp = m_warp; @@ -371,7 +372,8 @@ std::pair GemmWarpPolicyNode::ComputeWarpPartition( ICHECK(0) << "Unknown GemmWarpPolicy"; } ICHECK(m_warp * n_warp == num_warps) - << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp << ", n_warp: " << n_warp << ", num_warps: " << num_warps; + << "m_warp * n_warp must equal num_warps, m_warp: " << m_warp + << ", n_warp: " << n_warp << ", num_warps: " << num_warps; // Store the computed values in the object's member variables this->m_warp = m_warp; diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index dadeaf7af..125cbd18a 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -1,6 +1,6 @@ """The language interface for tl programs.""" -from typing import Union, List, Optional, Literal +from typing import Union, Optional, Literal from tilelang import language as T from tilelang.utils.language import get_buffer_region_from_load from tvm import ir, tir diff --git a/tilelang/language/customize.py b/tilelang/language/customize.py index ac9b36a50..e31cce4a6 100644 --- a/tilelang/language/customize.py +++ 
b/tilelang/language/customize.py @@ -1,10 +1,10 @@ """The language interface for tl programs.""" import tilelang.language as T -from tvm.tir import PrimExpr, Buffer, BufferLoad, BufferRegion, op +from tvm.tir import PrimExpr, Buffer, op from typing import List, Union from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401 -from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region + def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr: """Perform a 4-element dot product with accumulation (DP4A). diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py index e97af60d8..1d7dfb0c0 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -1,7 +1,7 @@ from tilelang import tvm as tvm from typing import List from tvm import tir -from tvm.tir import PrimExpr, Buffer, BufferLoad, BufferRegion, op +from tvm.tir import PrimExpr, Buffer, BufferLoad, op from tilelang import language as T @@ -26,6 +26,7 @@ def region(buffer: BufferLoad, access_type: str, *args: PrimExpr): access_type = {"r": 1, "w": 2, "rw": 3}[access_type] return T.call_intrin("handle", op.Op.get("tl.region"), buffer, access_type, *args) + def buffer_to_tile_region(buffer: Buffer, access_type: str): """Convert a TVM buffer to a tile region descriptor. @@ -40,6 +41,7 @@ def buffer_to_tile_region(buffer: Buffer, access_type: str): extents = [x for x in buffer.shape] return region(T.BufferLoad(buffer, mins), access_type, *extents) + def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List[PrimExpr]): """Convert a buffer load operation to a tile region descriptor. @@ -87,6 +89,7 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents) + def index_to_coordinates(index, shape) -> List[PrimExpr]: """ Convert a flat (linear) index into multi-dimensional coordinates for a given shape. 
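For context on the use_atomic switch these examples gain, here is a toy PyTorch sketch contrasting the two dK/dV accumulation strategies; the shapes are arbitrary, and plain tensor ops stand in for the in-kernel T.atomic_add and for the post-kernel sum over the groups axis:

import torch

groups, batch, seq, head_kv, dim = 4, 2, 8, 2, 16
# Per-group partial gradients, as each query-head group would produce them.
partial_dk = torch.randn(groups, batch, seq, head_kv, dim)

# "atomic" path: a single fp32 buffer that every group accumulates into
# (the kernels do this with T.atomic_add on a [batch, seq, head_kv, dim] dK).
dk_atomic = torch.zeros(batch, seq, head_kv, dim)
for g in range(groups):
    dk_atomic += partial_dk[g]

# "split" path: keep a [groups, ...] buffer and reduce it after the kernel.
dk_split = partial_dk.sum(dim=0)

torch.testing.assert_close(dk_atomic, dk_split)
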
From c2ec489d7351975afce3b30af923bc280b431f20 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 5 Oct 2025 20:02:21 +0800 Subject: [PATCH 3/4] bug fix --- tilelang/language/utils.py | 3 +-- tilelang/utils/language.py | 7 ++++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py index 1d7dfb0c0..358c2c890 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -54,8 +54,7 @@ def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List tir.Call: A region descriptor for the loaded area """ indices = load.indices - print("indices", indices) - print("extents", extents) + if len(indices) > len(extents): # (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, " # f"region will be expanded in the last 2 dimensions") diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py index 82b969460..2c0b4efad 100644 --- a/tilelang/utils/language.py +++ b/tilelang/utils/language.py @@ -131,11 +131,16 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> Optional[tir.Buf """ buffer, indices = buffer_load.buffer, buffer_load.indices regions = [] + found_ramp: bool = False for indice in indices: if isinstance(indice, tir.Ramp): regions.append(ir.Range.from_min_extent(indice.base, indice.lanes)) + found_ramp = True elif isinstance(indice, tir.PrimExpr): regions.append(ir.Range.from_min_extent(indice, 1)) else: raise ValueError("Unsupported type: ", type(indice)) - return tir.BufferRegion(buffer, regions) + if found_ramp: + return tir.BufferRegion(buffer, regions) + else: + return None From 11a46c88cc7d28d0f47cae18afe934c4df500ece Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 5 Oct 2025 20:35:02 +0800 Subject: [PATCH 4/4] reduce test size. --- examples/warp_specialize/example_warp_specialize_flashmla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/warp_specialize/example_warp_specialize_flashmla.py b/examples/warp_specialize/example_warp_specialize_flashmla.py index c9f664efd..c52dd15c1 100644 --- a/examples/warp_specialize/example_warp_specialize_flashmla.py +++ b/examples/warp_specialize/example_warp_specialize_flashmla.py @@ -382,7 +382,7 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): return out -def main(batch=1, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64): +def main(batch=1, heads=64, kv_heads=1, kv_ctx=1024, dim=512, pe_dim=64): qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim) pv_flops = 2 * batch * heads * kv_ctx * dim total_flops = qk_flops + pv_flops
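
A quick arithmetic check of the FLOP count computed right above, under the reduced default sizes (batch=1, heads=64, kv_ctx=1024, dim=512, pe_dim=64); the formulas are copied from main:

batch, heads, kv_ctx, dim, pe_dim = 1, 64, 1024, 512, 64
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)  # 75,497,472
pv_flops = 2 * batch * heads * kv_ctx * dim             # 67,108,864
total_flops = qk_flops + pv_flops                       # 142,606,336 (~0.14 GFLOP)
print(total_flops / 1e9, "GFLOP")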
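
Separately, a hedged usage sketch of get_buffer_region_from_load as patched in the third commit, assuming a working tvm/tilelang install; the buffer shape and indices here are arbitrary illustrations, not values taken from the examples:

from tvm import tir
from tilelang.utils.language import get_buffer_region_from_load

A = tir.decl_buffer((64, 64), "float16", name="A")

# Scalar load A[3, 5]: no Ramp index, so the helper now returns None.
scalar_load = tir.BufferLoad(A, [tir.const(3, "int32"), tir.const(5, "int32")])
print(get_buffer_region_from_load(scalar_load))  # None

# Vectorized load over A[3, 0:8] expressed with a Ramp index: still a region.
ramp = tir.Ramp(tir.const(0, "int32"), tir.const(1, "int32"), 8)
vector_load = tir.BufferLoad(A, [tir.const(3, "int32"), ramp])
print(get_buffer_region_from_load(vector_load))  # a 1x8 BufferRegion over A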