Skip to content

[Cute, flex, sm90] fix sm90 flex#2563

Merged
drisspg merged 1 commit into
Dao-AILab:mainfrom
geruome:wzh/fix_flex
May 13, 2026
Merged

[Cute, flex, sm90] fix sm90 flex#2563
drisspg merged 1 commit into
Dao-AILab:mainfrom
geruome:wzh/fix_flex

Conversation

@geruome
Copy link
Copy Markdown
Contributor

@geruome geruome commented May 13, 2026

pass the seqlen para for produce_block_sparse_loads in sm90 fwd.

Break by https://github.com/Dao-AILab/flash-attention/pull/2224

A simple bug reproduce script on sm90:

from functools import partial
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

torch.manual_seed(0)

B, Hq, Hkv, S, D = 1, 32, 8, 2048, 128
q = torch.randn(B, Hq, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, Hkv, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, Hkv, S, D, device="cuda", dtype=torch.bfloat16)

def mask_mod(b, h, q_idx, kv_idx):
    return (q_idx >= kv_idx) & ((q_idx - kv_idx <= 1023) | (kv_idx < 4))

bm = create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda", _compile=True)

flash = torch.compile(partial(flex_attention, kernel_options={"BACKEND": "FLASH"}), dynamic=None)
triton = torch.compile(partial(flex_attention, kernel_options={"BACKEND": "TRITON"}), dynamic=None)

out_flash = flash(q, k, v, block_mask=bm, enable_gqa=True)
out_triton = triton(q, k, v, block_mask=bm, enable_gqa=True)
diff = (out_flash.float() - out_triton.float()).abs()
print("max_diff", diff.max().item(), "mean_diff", diff.mean().item())
torch.testing.assert_close(out_flash, out_triton, atol=5e-2, rtol=5e-2)

And I wonder why such a bug was not detected by the CI ??

@drisspg drisspg merged commit 9cee95f into Dao-AILab:main May 13, 2026
reubenconducts pushed a commit to reubenconducts/flash-attention that referenced this pull request Jun 2, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants