From a5bb07ea4b0044ffb01bb0e4615c8b686ad5f558 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 7 Jan 2026 01:36:59 +0000 Subject: [PATCH] [CUTE][SM100] Fix backward gqa on sm100 post mask-mod semantic change stack-info: PR: https://github.com/Dao-AILab/flash-attention/pull/2146, branch: drisspg/stack/11 --- tests/cute/test_mask_mod.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 59409862406..bad320fe5ce 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -95,7 +95,7 @@ def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: tuple[int, i device=q.device, **block_mask_kwargs, ) - out_ref = flex_attention(q, k, v, block_mask=block_mask, scale=scale) + out_ref = flex_attention(q, k, v, block_mask=block_mask, scale=scale, enable_gqa=True) return out_ref.transpose(1, 2).contiguous() @@ -809,7 +809,7 @@ def run_flex_reference_bwd(q, k, v, block_mask, grad_out, dtype=None): # Use flex_attention directly without torch.compile for backward tests # torch.compile can hang on certain mask patterns (e.g., mini_causal with float32) - out_ref = flex_attention(q_ref, k_ref, v_ref, block_mask=block_mask) + out_ref = flex_attention(q_ref, k_ref, v_ref, block_mask=block_mask, enable_gqa=True) dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), grad_out_ref) # Transpose back to BSHD