13 changes: 7 additions & 6 deletions examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
@@ -7,8 +7,6 @@
 from einops import rearrange, repeat
 from bert_padding import pad_input, unpad_input
 
-torch.manual_seed(1)
-
 
 def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
     assert mode in ["full", "random", "third"]
@@ -525,7 +523,10 @@ def flash_bwd(
             T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
             for i, j in T.Parallel(block_N, dim_qk):
                 if k_base * block_N + i < q_current_seqlen:
-                    T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j])
+                    T.atomic_add(
+                        dQ[q_start_idx + k_base * block_N + i, bx, j],
+                        dq[i, j],
+                        memory_order="release")
 
         T.copy(dv, dv_shared)
         for i, d in T.Parallel(block_M, dim_v):
@@ -739,9 +740,9 @@ def main(BATCH: int = 1,
     dV_ref, V.grad = V.grad.clone(), None
 
     torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2)
-    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
     print('All checks passed.✅')
 
     def run():
@@ -784,8 +785,8 @@ def run1():
     elif args.use_atomic:
         use_atomic = True
     else:
-        # Default: use atomic
-        use_atomic = True
+        # Default: use split
+        use_atomic = False
 
     main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
          use_atomic)
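
For orientation on what the use_atomic toggle selects, here is a hedged, host-side PyTorch sketch of the two dQ-accumulation strategies; it is an analogy only, not the TileLang kernel, and the function names, shapes, and split bookkeeping are illustrative assumptions. The memory_order="release" added to T.atomic_add above applies to the on-device atomic path, where release ordering conventionally makes a block's earlier writes visible to readers that observe the accumulated value; the serial sketch below has no such concern.

# Hedged illustration only (plain PyTorch, not the TileLang kernel): the two
# dQ-accumulation strategies the example toggles with use_atomic. Function
# names, shapes, and the split bookkeeping below are assumptions for clarity.
import torch


def accumulate_dq_atomic(partials, row_idx, total_rows, heads, dim_qk):
    # Atomic-style: every tile adds its partial dQ block straight into one
    # shared buffer; index_add_ stands in for the per-element atomic adds.
    dQ = torch.zeros(total_rows, heads, dim_qk)
    for tile, rows in zip(partials, row_idx):
        dQ.index_add_(0, rows, tile)
    return dQ


def accumulate_dq_split(partials, row_idx, total_rows, heads, dim_qk, n_split):
    # Split-style: each split owns a private dQ buffer written without
    # contention; the buffers are summed in a separate reduction pass.
    dQ_split = torch.zeros(n_split, total_rows, heads, dim_qk)
    for s, (tile, rows) in enumerate(zip(partials, row_idx)):
        # Rows within a single tile are unique, so indexed += is safe here.
        dQ_split[s % n_split, rows] += tile
    return dQ_split.sum(dim=0)


if __name__ == "__main__":
    # Two partial tiles that both touch row 1: the atomic and split paths
    # must agree on the accumulated result.
    partials = [torch.ones(2, 1, 4), torch.ones(2, 1, 4)]
    row_idx = [torch.tensor([0, 1]), torch.tensor([1, 2])]
    a = accumulate_dq_atomic(partials, row_idx, total_rows=3, heads=1, dim_qk=4)
    b = accumulate_dq_split(partials, row_idx, 3, 1, 4, n_split=2)
    torch.testing.assert_close(a, b)

The atomic path lets concurrent blocks hit the same dQ rows and relies on hardware atomics; the split path presumably trades that contention for extra memory plus a final reduction. Per this diff, the example now defaults to the split path.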
6 changes: 6 additions & 0 deletions examples/flash_attention/test_example_flash_attention.py
@@ -12,6 +12,12 @@
 import example_mha_fwd_varlen
 import example_mha_bwd_wgmma_pipelined
 import example_mha_fwd_bhsd
+import example_gqa_bwd_tma_reduce_varlen
+
+
+@tilelang.testing.requires_cuda
+def test_example_gqa_bwd_tma_reduce_varlen():
+    example_gqa_bwd_tma_reduce_varlen.main()
 
Comment on lines +18 to 21
⚠️ Potential issue | 🟠 Major

Test does not exercise “split” path and lacks compute-capability guard.

As written, main() runs atomic path (default True) and may fail on < SM90 GPUs. Call split explicitly and add the same guard used elsewhere.

-@tilelang.testing.requires_cuda
-def test_example_gqa_bwd_tma_reduce_varlen():
-    example_gqa_bwd_tma_reduce_varlen.main()
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_example_gqa_bwd_tma_reduce_varlen():
+    example_gqa_bwd_tma_reduce_varlen.main(use_atomic=False)
🤖 Prompt for AI Agents
In examples/flash_attention/test_example_flash_attention.py around lines 18-21,
the test calls example_gqa_bwd_tma_reduce_varlen.main() which by default runs
the atomic path (may fail on GPUs < SM90); modify the test to explicitly invoke
the split path by calling main(split=True) and add the same compute-capability
guard used elsewhere (e.g., @tilelang.testing.requires_sm90 or equivalent) above
the test so it only runs on supported GPUs.


@tilelang.testing.requires_cuda