diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
index 82d363768..159f0d407 100644
--- a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
+++ b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
@@ -7,8 +7,6 @@
 from einops import rearrange, repeat
 from bert_padding import pad_input, unpad_input
 
-torch.manual_seed(1)
-
 
 def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
     assert mode in ["full", "random", "third"]
@@ -525,7 +523,10 @@ def flash_bwd(
                 T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                 for i, j in T.Parallel(block_N, dim_qk):
                     if k_base * block_N + i < q_current_seqlen:
-                        T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j])
+                        T.atomic_add(
+                            dQ[q_start_idx + k_base * block_N + i, bx, j],
+                            dq[i, j],
+                            memory_order="release")
 
             T.copy(dv, dv_shared)
             for i, d in T.Parallel(block_M, dim_v):
@@ -739,9 +740,9 @@ def main(BATCH: int = 1,
     dV_ref, V.grad = V.grad.clone(), None
 
     torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2)
-    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
     print('All checks passed.✅')
 
     def run():
@@ -784,8 +785,8 @@ def run1():
     elif args.use_atomic:
         use_atomic = True
     else:
-        # Default: use atomic
-        use_atomic = True
+        # Default: use split
+        use_atomic = False
 
     main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
          use_atomic)
diff --git a/examples/flash_attention/test_example_flash_attention.py b/examples/flash_attention/test_example_flash_attention.py
index 8a58f3b6a..527d89cd0 100644
--- a/examples/flash_attention/test_example_flash_attention.py
+++ b/examples/flash_attention/test_example_flash_attention.py
@@ -12,6 +12,12 @@
 import example_mha_fwd_varlen
 import example_mha_bwd_wgmma_pipelined
 import example_mha_fwd_bhsd
+import example_gqa_bwd_tma_reduce_varlen
+
+
+@tilelang.testing.requires_cuda
+def test_example_gqa_bwd_tma_reduce_varlen():
+    example_gqa_bwd_tma_reduce_varlen.main()
 
 
 @tilelang.testing.requires_cuda