Clamp kv_stage to avoid SMEM overflow for small head_dims on SM100 #2594
Cc: @rm-wu for review
LGTM! I verified again on B200 and all head_dim values pass:
python: 3.12.3
torch: 2.12.0+cu130
device: NVIDIA B200 sm_100 (capability=(10, 0))
Calling flash_attn.cute._flash_attn_fwd with varying head_dim:
(batch=1, seqlen=2048, num_heads=2, bf16, causal=True)
head_dim= 8: OK out.shape=(1, 2048, 2, 8) lse.shape=(1, 2, 2048)
head_dim= 16: OK out.shape=(1, 2048, 2, 16) lse.shape=(1, 2, 2048)
head_dim= 32: OK out.shape=(1, 2048, 2, 32) lse.shape=(1, 2, 2048)
head_dim= 64: OK out.shape=(1, 2048, 2, 64) lse.shape=(1, 2, 2048)
head_dim= 96: OK out.shape=(1, 2048, 2, 96) lse.shape=(1, 2, 2048)
head_dim=128: OK out.shape=(1, 2048, 2, 128) lse.shape=(1, 2, 2048)
    # actual SMEM usage. At hd_padded=16 the unbounded formula yields 52+ stages and
    # overflows the sm_100a 227 KB cap. Clamp to an upper bound safely past the
    # pipelining-saturation point; attention pipelining saturates well before 32 stages.
    kv_stage = min((224 * 1024 - smem_size_q_o) // smem_size_kv_per_stage, 32)
do we actually want to be using 32 kv stages?
Good catch, @drisspg
32 is essentially "the next round number above 26" — picked to be surgical to the broken case. The unbounded formula only exceeds 32 at head_dim_padded ∈ {8, 16}:
| hd_padded | kv_stage (q_stage=1) | kv_stage (q_stage=2) |
|---|---|---|
| 8 → 16 | 108 | 104 |
| 16 | 108 | 52 |
| 32 | 26 | 24 |
| 64 | 12 | 10 |
| 96 | 7 | 5 |
| 128 | 5 | 3 |
(1CTA path; 2CTA is gated to hd_padded ∈ {128, 192} at interface.py:572 and its max value there is 10, also below 32.)
So min(..., 32) only fires at hd_padded ∈ {8, 16}; anything from 27 to ~50 would be equally surgical. Dropping below 26 starts perturbing kernel staging for hd_padded ∈ {32, 64} (clamp=8 would give 3× fewer stages at hd=32) — and we only have perf data at hd=16, not those.
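To make the table above easier to check, here is a minimal sketch of the staging arithmetic. The budget (`224 * 1024`) and the clamp (`32`) come from the diff; the 128-row tile size and the "one Q tile plus one O tile per q_stage" accounting are assumptions fitted to the table, not values read from flash_fwd_sm100.py, and the function names are hypothetical.

```python
SMEM_BUDGET = 224 * 1024  # budget used by the formula quoted in the diff
KV_STAGE_CAP = 32         # clamp introduced by this PR
TILE_ROWS = 128           # ASSUMPTION: rows per Q/K/V tile
BYTES_PER_ELEM = 2        # bf16


def unbounded_kv_stage(hd_padded: int, q_stage: int) -> int:
    """Stage count before the clamp, under the assumptions above."""
    tile_bytes = TILE_ROWS * hd_padded * BYTES_PER_ELEM
    # ASSUMPTION: each q_stage holds one Q tile and one O tile.
    smem_size_q_o = 2 * q_stage * tile_bytes
    smem_size_kv_per_stage = tile_bytes
    return (SMEM_BUDGET - smem_size_q_o) // smem_size_kv_per_stage


def clamped_kv_stage(hd_padded: int, q_stage: int) -> int:
    """Stage count after applying min(..., 32) as in the diff."""
    return min(unbounded_kv_stage(hd_padded, q_stage), KV_STAGE_CAP)


if __name__ == "__main__":
    for hd in (32, 64, 96, 128):
        print(hd, unbounded_kv_stage(hd, 1), unbounded_kv_stage(hd, 2))
    print("hd_padded=16, q_stage=2:",
          unbounded_kv_stage(16, 2), "->", clamped_kv_stage(16, 2))
```

Under these assumptions the sketch matches the hd_padded ≥ 32 rows and the q_stage=2 entry for hd_padded=16 (52, clamped to 32); the other small-head-dim entries do not follow from this simplified model, so treat it as an illustration of why the clamp is a no-op from hd_padded=32 upward rather than as the kernel's exact accounting.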
At hd=16 itself, I swept the clamp on B200 (batch=4, nheads=16, hd=16, bf16, causal=False):
| clamp | seqlen=4096 time (ms) | seqlen=4096 TFLOPS | seqlen=16384 time (ms) | seqlen=16384 TFLOPS |
|---|---|---|---|---|
| 2 | 0.3778 | 181.9 | 1.4705 | 186.9 |
| 4 | 0.3818 | 180.0 | 1.4858 | 185.0 |
| 8 | 0.3819 | 180.0 | 1.4860 | 185.0 |
| 16 | 0.3824 | 179.7 | 1.4874 | 184.8 |
| 32 | 0.3830 | 179.4 | 1.4883 | 184.7 |
Everything is within ~1.5%, so the clamp value doesn't measurably matter at hd=16 either; keeping 32 just avoids changing kernel staging anywhere outside the broken case.
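For anyone re-running the sweep, the TFLOPS figures can be recomputed from the timings. The sketch below assumes the usual forward-attention FLOP count of 4 · batch · nheads · seqlen² · head_dim (two matmuls at 2 FLOPs per multiply-add), halved for causal masking; the helper name is hypothetical and this is not the benchmark script used for the table.

```python
def attention_fwd_tflops(batch: int, nheads: int, seqlen: int, head_dim: int,
                         time_ms: float, causal: bool = False) -> float:
    """Achieved TFLOPS for one forward attention call, from its runtime in ms."""
    # ASSUMPTION: QK^T and PV each cost 2 * seqlen^2 * head_dim FLOPs per head.
    flops = 4 * batch * nheads * seqlen * seqlen * head_dim
    if causal:
        flops //= 2  # roughly half of the score matrix is masked out
    return flops / (time_ms * 1e-3) / 1e12


if __name__ == "__main__":
    fastest = attention_fwd_tflops(4, 16, 4096, 16, 0.3778)  # clamp=2 row
    slowest = attention_fwd_tflops(4, 16, 4096, 16, 0.3830)  # clamp=32 row
    print(f"{fastest:.1f} vs {slowest:.1f} TFLOPS, "
          f"spread {100 * (fastest / slowest - 1):.2f}%")
```

With the stated batch=4 this reproduces the seqlen=4096 columns. The seqlen=16384 columns match the same formula only with batch=1, which is an inference from the arithmetic rather than something stated in the thread.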
Fixes Dao-AILab#2591.

The unbounded formula at flash_fwd_sm100.py:335 ignores per-stage state (mbarriers, sScale, pipeline counters) and yields kv_stage values that overflow the sm_100a 227 KB SMEM cap when head_dim_padded=16 (head_dim in {8, ..., 16}).

Repro: hd=8/16 + seqlen >= 256 + bf16 fails with cudaErrorInvalidValue ("launch shared memory exceeds current GPU arch sm_100a allowed. Allocated: 233472 bytes. Max: 232448 bytes.").

Clamp kv_stage at 32. Surgical to the broken case: the unbounded formula maxes at 26 stages for head_dim_padded >= 32, and the 2CTA gate at interface.py:572 restricts 2CTA to hd_padded in {128, 192} (both no-op), so the clamp only fires at hd_padded in {8, 16}.

Verified across 24 configs (hd in {8,16,32,64,96,128} x causal in {T,F} x seqlen in {128,2048}) on B200 with max_err vs torch SDPA <= 0.0078.
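The numbers in this commit message can be cross-checked with a few lines of arithmetic. This is only a sanity-check sketch: the constants are copied from the error message and the verification grid quoted above, and the variable names are illustrative.

```python
from itertools import product

# Figures quoted in the launch error.
SMEM_CAP_BYTES = 227 * 1024  # sm_100a cap, reported as "Max: 232448 bytes"
ALLOCATED_BYTES = 233472     # reported as "Allocated: 233472 bytes"
overflow_bytes = ALLOCATED_BYTES - SMEM_CAP_BYTES

# Verification grid: head_dim x causal x seqlen.
head_dims = [8, 16, 32, 64, 96, 128]
causal_flags = [True, False]
seqlens = [128, 2048]
configs = list(product(head_dims, causal_flags, seqlens))

if __name__ == "__main__":
    print(f"cap = {SMEM_CAP_BYTES} bytes, overflow = {overflow_bytes} bytes")
    print(f"{len(configs)} configs")
```

The overflow works out to exactly 1 KB, and the grid expands to the 24 configurations mentioned in the message.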
Force-pushed from 5d50760 to 53453d8.
The main test_flash_attn_output parametrizes d over {64, 96, 128, 192, 256}
and never exercises head_dim < 64, even though _validate_head_dims accepts
head_dim >= 8 for sm_100/110. That coverage gap let the SMEM-overflow bug
in Dao-AILab#2591 slip through.
This focused test covers d in {8, 16, 32} x causal x seqlen in {128, 2048}.
The seqlen=2048 cases push q_stage 1->2 (the actual bug trigger); the
seqlen=128 cases also exercise the q_stage=1 boundary that fits on main
today but is structurally adjacent. d=32 serves as a canary against any
future tighter kv_stage clamp regressing it.
Force-pushed from 53453d8 to 9d68936.
Summary

Fixes #2591.

flash_fwd_sm100.py:335 computes kv_stage without accounting for per-stage overhead (mbar_*, sScale, pipeline counters). For head_dim_padded=16, the heuristic can select 52+ stages, exceeding the sm_100a shared memory limit by ~1 KB and causing kernel launch failure.

This affects small head dimensions (head_dim ∈ [8,16]) with non-trivial sequence lengths where q_stage=2.

Fix

Clamp kv_stage to 32.

This is sufficient to avoid the SMEM overflow and does not impact performance in practice, since attention pipelining saturates well before 32 stages. Existing head_dim_padded >= 64 configurations are unchanged because they already use ≤24 stages.

Test plan

- Added test_flash_attn_small_head_dim covering:
  - d ∈ {8,16,32}
  - causal ∈ {True, False}
  - seqlen ∈ {128, 2048}
- Verified that the original #2591 repro ([FA4, CuTe, SM100] _flash_attn_fwd fails at launch for head_dim < 32 on B200) now passes.
- Verified on B200 (sm_100a): all new tests pass with max error vs Torch SDPA ≤ 0.0078.
- Confirmed no behavior change for head_dim_padded >= 64.

Notes

An alternative fix would incorporate per-stage overhead directly into the heuristic, but that reduces stage count across all configurations. The clamp is a smaller, lower-risk fix for this bug.

_validate_head_dims remains unchanged: small head dimensions are valid on sm_100/110; only the staging heuristic was incorrect.