Clamp kv_stage to avoid SMEM overflow for small head_dims on SM100 #2594

Merged
Johnsonms merged 2 commits into Dao-AILab:main from Johnsonms:johnson/fix-hd-small-smem-overflow
May 28, 2026

Conversation

@Johnsonms
Collaborator

@Johnsonms Johnsonms commented May 26, 2026

Summary

Fixes #2591.

flash_fwd_sm100.py:335 computes kv_stage without accounting for per-stage overhead (mbar_*, sScale, pipeline counters). For head_dim_padded=16, the heuristic can select 52+ stages, exceeding the sm_100a shared memory limit by ~1 KB and causing kernel launch failure:

cudaErrorInvalidValue: launch shared memory exceeds current GPU arch sm_100a allowed
Allocated: 233472 bytes. Max: 232448 bytes.

This affects small head dimensions (head_dim ∈ [8,16]) with non-trivial sequence lengths where q_stage=2.
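The figures in the error message can be checked with a few lines of arithmetic. This is a minimal sketch: the constant names are illustrative and do not come from the repository, and the 224 KB budget is the one that appears in the kv_stage formula quoted in the review thread.

```python
# Values copied from the cudaErrorInvalidValue message.
ALLOCATED_BYTES = 233472
MAX_BYTES = 232448

# Budget used by the kv_stage heuristic (224 * 1024 in the reviewed line).
HEURISTIC_BUDGET = 224 * 1024

overflow_bytes = ALLOCATED_BYTES - MAX_BYTES           # how far past the hardware cap
cap_kib = MAX_BYTES / 1024                             # hardware cap in KiB
headroom_bytes = MAX_BYTES - HEURISTIC_BUDGET          # slack between budget and cap
over_budget_bytes = ALLOCATED_BYTES - HEURISTIC_BUDGET # allocation beyond the budget

print(f"cap: {cap_kib:.0f} KiB, overflow: {overflow_bytes} bytes")
print(f"headroom above budget: {headroom_bytes} bytes, "
      f"allocation above budget: {over_budget_bytes} bytes")
```

The allocation lands 4096 bytes above the heuristic's budget, while the cap leaves only 3072 bytes of slack above that budget, which accounts for the 1024-byte overflow.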

Fix

Clamp kv_stage to 32.

This is sufficient to avoid the SMEM overflow and does not impact performance in practice, since attention pipelining saturates well before 32 stages. Existing head_dim_padded >= 64 configurations are unchanged because they already use ≤24 stages.
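The clamp can be sketched as a standalone function. The function name and the byte sizes in the usage lines are hypothetical and chosen only for illustration; the formula itself mirrors the reviewed line in flash_fwd_sm100.py.

```python
def kv_stage_clamped(smem_size_q_o: int,
                     smem_size_kv_per_stage: int,
                     budget: int = 224 * 1024,
                     max_stages: int = 32) -> int:
    """Stage count from the SMEM budget, capped at max_stages."""
    unbounded = (budget - smem_size_q_o) // smem_size_kv_per_stage
    return min(unbounded, max_stages)


# Hypothetical sizes (not measured): a small per-stage footprint makes the
# unbounded value 52, so the clamp fires and returns 32.
small = kv_stage_clamped(smem_size_q_o=16384, smem_size_kv_per_stage=4096)

# Hypothetical sizes: a larger per-stage footprint gives 12, below the clamp,
# so the result is unchanged.
large = kv_stage_clamped(smem_size_q_o=32768, smem_size_kv_per_stage=16384)

print(small, large)
```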

Test plan

Notes

An alternative fix would incorporate per-stage overhead directly into the heuristic:

(BUDGET - smem_qo) // (smem_kv + PER_STAGE_OVERHEAD)

but that reduces stage count across all configurations. The clamp is a smaller, lower-risk fix for this bug.
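To illustrate why the alternative touches every configuration, here is a rough sketch of that formula. The function name, the 128-byte overhead, and all byte sizes are assumptions made for the example, not values from the kernel.

```python
BUDGET = 224 * 1024  # same budget as the existing heuristic


def kv_stage_with_overhead(smem_qo: int, smem_kv: int,
                           per_stage_overhead: int) -> int:
    """Alternative heuristic: charge each stage for its bookkeeping bytes."""
    return (BUDGET - smem_qo) // (smem_kv + per_stage_overhead)


# Hypothetical configurations as (smem_qo, smem_kv) pairs in bytes.
configs = {"small": (16384, 4096), "large": (32768, 16384)}

for name, (qo, kv) in configs.items():
    without = kv_stage_with_overhead(qo, kv, per_stage_overhead=0)
    with_oh = kv_stage_with_overhead(qo, kv, per_stage_overhead=128)
    print(f"{name}: {without} -> {with_oh} stages")
```

With these made-up numbers the small configuration drops from 52 to 50 stages and the large one from 12 to 11, so configurations that never overflowed would also change.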

_validate_head_dims remains unchanged: small head dimensions are valid on sm_100/110; only the staging heuristic was incorrect.

@Johnsonms
Collaborator Author

Johnsonms commented May 26, 2026

Cc: @rm-wu for review

@Johnsonms Johnsonms marked this pull request as ready for review May 26, 2026 03:22
@rm-wu

rm-wu commented May 26, 2026

LGTM! I verified again on B200: every head_dim in {8, 16, 32, 64, 96, 128} passes with this PR, using the repro script in #2591:

python: 3.12.3
torch:  2.12.0+cu130
device: NVIDIA B200 sm_100  (capability=(10, 0))
Calling flash_attn.cute._flash_attn_fwd with varying head_dim:
(batch=1, seqlen=2048, num_heads=2, bf16, causal=True)
head_dim=  8: OK    out.shape=(1, 2048, 2, 8)  lse.shape=(1, 2, 2048)
head_dim= 16: OK    out.shape=(1, 2048, 2, 16)  lse.shape=(1, 2, 2048)
head_dim= 32: OK    out.shape=(1, 2048, 2, 32)  lse.shape=(1, 2, 2048)
head_dim= 64: OK    out.shape=(1, 2048, 2, 64)  lse.shape=(1, 2, 2048)
head_dim= 96: OK    out.shape=(1, 2048, 2, 96)  lse.shape=(1, 2, 2048)
head_dim=128: OK    out.shape=(1, 2048, 2, 128)  lse.shape=(1, 2, 2048)

@Johnsonms Johnsonms requested review from drisspg, jayhshah and tridao May 26, 2026 23:56
# actual SMEM usage. At hd_padded=16 the unbounded formula yields 52+ stages and
# overflows the sm_100a 227 KB cap. Clamp to an upper bound safely past the
# pipelining-saturation point; attention pipelining saturates well before 32 stages.
kv_stage = min((224 * 1024 - smem_size_q_o) // smem_size_kv_per_stage, 32)
Collaborator

do we actually want to be using 32 kv stages?

Collaborator Author

Good catch, @drisspg
32 is essentially "the next round number above 26" — picked to be surgical to the broken case. The unbounded formula only exceeds 32 at head_dim_padded ∈ {8, 16}:

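| hd_padded | unbounded kv_stage (q=1) | unbounded kv_stage (q=2) |
|-----------|--------------------------|--------------------------|
| 8 → 16    | 108                      | 104                      |
| 16        | 108                      | 52                       |
| 32        | 26                       | 24                       |
| 64        | 12                       | 10                       |
| 96        | 7                        | 5                        |
| 128       | 5                        | 3                        |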

(1CTA path; 2CTA is gated to hd_padded ∈ {128, 192} at interface.py:572 and its max value there is 10, also below 32.)

So min(..., 32) only fires at hd_padded ∈ {8, 16}; anything from 27 to ~50 would be equally surgical. Dropping below 26 starts perturbing kernel staging for hd_padded ∈ {32, 64} (clamp=8 would give 3× fewer stages at hd=32) — and we only have perf data at hd=16, not those.
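The claim that the clamp only fires on the two smallest rows can be checked mechanically against the table. This is a small sketch; the dictionary simply transcribes the table values, and the variable names are illustrative.

```python
# hd_padded row label -> unbounded kv_stage for (q=1, q=2), from the table.
UNBOUNDED = {
    "8 -> 16": (108, 104),
    "16": (108, 52),
    "32": (26, 24),
    "64": (12, 10),
    "96": (7, 5),
    "128": (5, 3),
}
CLAMP = 32

# Rows where min(value, CLAMP) differs from the unbounded value.
affected = [row for row, stages in UNBOUNDED.items()
            if any(s > CLAMP for s in stages)]

# Largest stage count among rows the clamp leaves alone.
max_untouched = max(s for row, stages in UNBOUNDED.items()
                    if row not in affected for s in stages)

# Effect of a much tighter clamp (8) on the hd_padded=32 row.
ratio_at_32 = [s / 8 for s in UNBOUNDED["32"]]

print(affected, max_untouched, ratio_at_32)
```

The largest untouched value is 26, which is why any clamp above it leaves hd_padded >= 32 alone, and a clamp of 8 would cut the hd_padded=32 row by a factor of 3 to 3.25.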

At hd=16 itself, swept clamp on B200 (batch=4, nheads=16, hd=16, bf16, causal=False):

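| clamp | sl=4096 (ms) | sl=4096 (TFLOPS) | sl=16384 (ms) | sl=16384 (TFLOPS) |
|-------|--------------|------------------|---------------|-------------------|
| 2     | 0.3778       | 181.9            | 1.4705        | 186.9             |
| 4     | 0.3818       | 180.0            | 1.4858        | 185.0             |
| 8     | 0.3819       | 180.0            | 1.4860        | 185.0             |
| 16    | 0.3824       | 179.7            | 1.4874        | 184.8             |
| 32    | 0.3830       | 179.4            | 1.4883        | 184.7             |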

Everything is within about 1.4%, so the clamp value doesn't measurably matter at hd=16 either; keeping 32 just avoids changing kernel staging anywhere outside the broken case.
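The spread across the sweep can be quantified from the reported timings. This is a minimal sketch that transcribes the millisecond columns from the sweep; the helper name is illustrative.

```python
# clamp -> (ms at sl=4096, ms at sl=16384), transcribed from the sweep.
SWEEP_MS = {
    2: (0.3778, 1.4705),
    4: (0.3818, 1.4858),
    8: (0.3819, 1.4860),
    16: (0.3824, 1.4874),
    32: (0.3830, 1.4883),
}


def relative_spread(values):
    """(slowest - fastest) / fastest for a list of timings."""
    return (max(values) - min(values)) / min(values)


spread_4096 = relative_spread([ms[0] for ms in SWEEP_MS.values()])
spread_16384 = relative_spread([ms[1] for ms in SWEEP_MS.values()])

print(f"sl=4096: {spread_4096:.2%}, sl=16384: {spread_16384:.2%}")
```

Both spreads come out between 1.2% and 1.4%, so the choice of clamp makes little difference in this sweep.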

Fixes Dao-AILab#2591. The unbounded formula at flash_fwd_sm100.py:335 ignores
per-stage state (mbarriers, sScale, pipeline counters) and yields kv_stage
values that overflow the sm_100a 227 KB SMEM cap when head_dim_padded=16
(head_dim in {8, ..., 16}). Repro: hd=8/16 + seqlen >= 256 + bf16 fails
with cudaErrorInvalidValue ("launch shared memory exceeds current GPU arch
sm_100a allowed. Allocated: 233472 bytes. Max: 232448 bytes.").

Clamp kv_stage at 32. Surgical to the broken case: the unbounded formula
maxes at 26 stages for head_dim_padded >= 32, and the 2CTA gate at
interface.py:572 restricts 2CTA to hd_padded in {128, 192} (both no-op),
so the clamp only fires at hd_padded in {8, 16}.

Verified across 24 configs (hd in {8,16,32,64,96,128} x causal in {T,F} x
seqlen in {128,2048}) on B200 with max_err vs torch SDPA <= 0.0078.
@Johnsonms Johnsonms force-pushed the johnson/fix-hd-small-smem-overflow branch from 5d50760 to 53453d8 Compare May 28, 2026 02:35
The main test_flash_attn_output parametrizes d over {64, 96, 128, 192, 256}
and never exercises head_dim < 64, even though _validate_head_dims accepts
head_dim >= 8 for sm_100/110. That coverage gap let the SMEM-overflow bug
in Dao-AILab#2591 slip through.

This focused test covers d in {8, 16, 32} x causal x seqlen in {128, 2048}.
The seqlen=2048 cases push q_stage 1->2 (the actual bug trigger); the
seqlen=128 cases also exercise the q_stage=1 boundary that fits on main
today but is structurally adjacent. d=32 serves as a canary against any
future tighter kv_stage clamp regressing it.
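The coverage grid described above can be enumerated as follows. This is only a sketch of the parameter space: the constant names and the ID format are assumptions, and the actual test body in the repository is not reproduced here.

```python
from itertools import product

HEAD_DIMS = (8, 16, 32)   # small head dims missing from the main test
CAUSAL = (False, True)
SEQLENS = (128, 2048)     # 2048 is the case that pushes q_stage from 1 to 2

CASES = list(product(HEAD_DIMS, CAUSAL, SEQLENS))


def case_id(d: int, causal: bool, seqlen: int) -> str:
    """Readable label for one (d, causal, seqlen) combination."""
    mask = "causal" if causal else "full"
    return f"d{d}-{mask}-s{seqlen}"


for case in CASES:
    print(case_id(*case))
```

A list like `CASES` could be passed to `pytest.mark.parametrize`, with `case_id` supplying the `ids`, to get one test per combination (12 in total).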
@Johnsonms Johnsonms force-pushed the johnson/fix-hd-small-smem-overflow branch from 53453d8 to 9d68936 Compare May 28, 2026 02:39
@Johnsonms Johnsonms requested a review from drisspg May 28, 2026 02:44
@Johnsonms Johnsonms merged commit 0bbb25a into Dao-AILab:main May 28, 2026
@Johnsonms Johnsonms deleted the johnson/fix-hd-small-smem-overflow branch May 30, 2026 06:14
reubenconducts pushed a commit to reubenconducts/flash-attention that referenced this pull request Jun 2, 2026
…ao-AILab#2594)

* Clamp kv_stage to avoid SMEM overflow for small head_dims on SM100

* Add test_flash_attn_small_head_dim regression test


Development

Successfully merging this pull request may close these issues.

[FA4, CuTe, SM100] _flash_attn_fwd fails at launch for head_dim < 32 on B200

3 participants