From f18e82983a097245c2e3ad30720791a44c22818e Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Tue, 28 Oct 2025 19:27:50 +0530
Subject: [PATCH 1/2] Keep the max of all blocks seen in scores_max for
 stability

---
 examples/flash_attention/example_mha_fwd_bhsd.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py
index f07f7a618..6acec5a17 100644
--- a/examples/flash_attention/example_mha_fwd_bhsd.py
+++ b/examples/flash_attention/example_mha_fwd_bhsd.py
@@ -84,6 +84,10 @@ def Softmax(
         T.copy(scores_max, scores_max_prev)
         T.fill(scores_max, -T.infinity(accum_dtype))
         T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+
+        for i in T.Parallel(block_M):
+            scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
+
         # To do causal softmax, we need to set the scores_max to 0 if it is -inf
         # This process is called Check_inf in FlashAttention3 code, and it only need to be done
         # in the first ceil_div(kBlockM, kBlockN) steps.

From 46cecb4192480dca3c02a56be6deca0b79a06a53 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Sun, 16 Nov 2025 15:21:43 +0000
Subject: [PATCH 2/2] ruff formatting

---
 examples/flash_attention/example_mha_fwd_bhsd.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py
index 6acec5a17..2a243525f 100644
--- a/examples/flash_attention/example_mha_fwd_bhsd.py
+++ b/examples/flash_attention/example_mha_fwd_bhsd.py
@@ -87,7 +87,7 @@ def Softmax(
 
         for i in T.Parallel(block_M):
             scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
-        
+
         # To do causal softmax, we need to set the scores_max to 0 if it is -inf
         # This process is called Check_inf in FlashAttention3 code, and it only need to be done
         # in the first ceil_div(kBlockM, kBlockN) steps.
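Note (not part of the patches above): the first commit implements the standard
online-softmax running-max fold used in FlashAttention. After reducing the
current block's row max into scores_max, it folds the previous max
(scores_max_prev) back in, so the running max never decreases across K/V
blocks and the later exp(score - scores_max) rescaling cannot blow up. Below
is a minimal NumPy sketch of that recurrence for a single 1-D row of scores;
the function name, the block splitting, and all variable names are
illustrative assumptions, not the TileLang API.

import numpy as np

def online_softmax(blocks):
    """Softmax over np.concatenate(blocks), computed one block at a time.

    Illustrative sketch only; mirrors the recurrence in the kernel, not
    its actual code.
    """
    m = -np.inf        # running max over ALL blocks seen so far
    l = 0.0            # running sum of exp(score - m)
    exp_blocks = []    # per-block numerators, kept rescaled to the current m
    for x in blocks:
        m_prev = m
        # Mirrors the patched lines: fold the previous running max back in
        # rather than keeping only the current block's max, so m never
        # decreases and every exp(x - m) stays <= 1.
        m = max(m, float(x.max()))
        scale = np.exp(m_prev - m)    # <= 1; rescales old partial sums
        l = l * scale + np.exp(x - m).sum()
        exp_blocks = [e * scale for e in exp_blocks]
        exp_blocks.append(np.exp(x - m))
    return np.concatenate(exp_blocks) / l

Checking against a direct softmax, e.g.
online_softmax(np.split(x, 4)) versus np.exp(x - x.max()) / np.exp(x - x.max()).sum(),
shows the blockwise result matches. Without the fold, a later block with a
smaller max would let m decrease, making the rescale factor exp(m_prev - m)
exceed 1; in the kernel's low-precision accumulators that is exactly the
overflow the patch guards against.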