Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,10 @@ def _setup_attributes(self):
smem_size_k_per_stage = self.n_block_size * self.head_dim_padded * self.k_dtype.width // 8
smem_size_v_per_stage = self.n_block_size * self.head_dim_v_padded * self.v_dtype.width // 8
smem_size_kv_per_stage = max(smem_size_k_per_stage, smem_size_v_per_stage) // self.cta_group_size
kv_stage = (224 * 1024 - smem_size_q_o) // smem_size_kv_per_stage
# Cap small head_dim from over-staging: the 224*1024 budget undercounts
# per-stage state, so at hd_padded=16 the unbounded formula picks 52 stages
# and overflows the 227 KB SMEM cap. No-op for hd_padded >= 32 (max 26).
kv_stage = min((224 * 1024 - smem_size_q_o) // smem_size_kv_per_stage, 32)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we actually want to be using 32 kv stages?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, @drisspg
32 is essentially "the next round number above 26" — picked to be surgical to the broken case. The unbounded formula only exceeds 32 at head_dim_padded ∈ {8, 16}:

hd_padded q=1 q=2
8 → 16 108 104
16 108 52
32 26 24
64 12 10
96 7 5
128 5 3

(1CTA path; 2CTA is gated to hd_padded ∈ {128, 192} at interface.py:572 and its max value there is 10, also below 32.)

So min(..., 32) only fires at hd_padded ∈ {8, 16}; anything from 27 to ~50 would be equally surgical. Dropping below 26 starts perturbing kernel staging for hd_padded ∈ {32, 64} (clamp=8 would give 3× fewer stages at hd=32) — and we only have perf data at hd=16, not those.

At hd=16 itself, swept clamp on B200 (batch=4, nheads=16, hd=16, bf16, causal=False):

clamp sl=4096 ms TFLOPS sl=16384 ms TFLOPS
2 0.3778 181.9 1.4705 186.9
4 0.3818 180.0 1.4858 185.0
8 0.3819 180.0 1.4860 185.0
16 0.3824 179.7 1.4874 184.8
32 0.3830 179.4 1.4883 184.7

Everything is within ~1% — clamp value doesn't measurably matter at hd=16 either, so keeping 32 just avoids changing kernel staging anywhere outside the broken case.

if self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and kv_stage == 2:
# For hdim 192,128, we can fit 3 stages if we use uneven_kv_smem
kv_stage = 3
Expand Down
44 changes: 44 additions & 0 deletions tests/cute/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,50 @@ def test_flash_attn_output(
).abs().max().item() + dv_atol


# Regression test for #2591: SMEM overflow at small head_dims on SM100. The main
# test_flash_attn_output skips d < 64, but _validate_head_dims accepts head_dim >= 8
# for sm_100/110, so this path needs coverage. Trigger requires
# seqlen_q_packgqa > tile_m to push q_stage 1->2.
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [8, 16, 32])
@pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128), (2048, 2048)])
@retry_on_oom
@maybe_fake_tensor_mode(USE_FAKE_TENSOR)
def test_flash_attn_small_head_dim(seqlen_q, seqlen_k, d, causal, dtype):
device = "cuda"
seed = 0
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.empty_cache()
torch.cuda.synchronize()
batch_size = 2
nheads = 2
nheads_kv = nheads
dtype_ref = dtype
q_ref = torch.randn(
batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref
).requires_grad_()
k_ref = torch.randn(
batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
).requires_grad_()
v_ref = torch.randn(
batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
).requires_grad_()
q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)]
out_ref, _ = attention_ref(q_ref, k_ref, v_ref, None, None, causal=causal)
out_pt, _ = attention_ref(
q_ref, k_ref, v_ref, None, None, causal=causal, upcast=False, reorder_ops=True
)
out, _ = flash_attn_func(q, k, v, causal=causal)
if is_fake_mode():
return
fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
assert (out - out_ref).abs().max().item() <= 2 * (
out_pt - out_ref
).abs().max().item() + fwd_atol


# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
Expand Down