From 23f5b23680c0b17aa68ed58c6919d1c56ba6690b Mon Sep 17 00:00:00 2001 From: johnsonms Date: Thu, 28 May 2026 02:35:35 +0000 Subject: [PATCH 1/2] Clamp kv_stage to avoid SMEM overflow for small head_dims on SM100 Fixes #2591. The unbounded formula at flash_fwd_sm100.py:335 ignores per-stage state (mbarriers, sScale, pipeline counters) and yields kv_stage values that overflow the sm_100a 227 KB SMEM cap when head_dim_padded=16 (head_dim in {8, ..., 16}). Repro: hd=8/16 + seqlen >= 256 + bf16 fails with cudaErrorInvalidValue ("launch shared memory exceeds current GPU arch sm_100a allowed. Allocated: 233472 bytes. Max: 232448 bytes."). Clamp kv_stage at 32. Surgical to the broken case: the unbounded formula maxes at 26 stages for head_dim_padded >= 32, and the 2CTA gate at interface.py:572 restricts 2CTA to hd_padded in {128, 192} (both no-op), so the clamp only fires at head_dim_padded=16 (head_dim in {8, ..., 16}). Verified across 24 configs (hd in {8,16,32,64,96,128} x causal in {T,F} x seqlen in {128,2048}) on B200 with max_err vs torch SDPA <= 0.0078. --- flash_attn/cute/flash_fwd_sm100.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 57755d12cb9..82638f341cd 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -336,7 +336,10 @@ def _setup_attributes(self): smem_size_k_per_stage = self.n_block_size * self.head_dim_padded * self.k_dtype.width // 8 smem_size_v_per_stage = self.n_block_size * self.head_dim_v_padded * self.v_dtype.width // 8 smem_size_kv_per_stage = max(smem_size_k_per_stage, smem_size_v_per_stage) // self.cta_group_size - kv_stage = (224 * 1024 - smem_size_q_o) // smem_size_kv_per_stage + # Cap small head_dim from over-staging: the 224*1024 budget undercounts + # per-stage state, so at hd_padded=16 the unbounded formula picks 52 stages + # and overflows the 227 KB SMEM cap. No-op for hd_padded >= 32 (max 26). 
+ kv_stage = min((224 * 1024 - smem_size_q_o) // smem_size_kv_per_stage, 32) if self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and kv_stage == 2: # For hdim 192,128, we can fit 3 stages if we use uneven_kv_smem kv_stage = 3 From 9d68936229d5f68c2fc3dc00cf19d01595731094 Mon Sep 17 00:00:00 2001 From: johnsonms Date: Thu, 28 May 2026 02:35:41 +0000 Subject: [PATCH 2/2] Add test_flash_attn_small_head_dim regression test The main test_flash_attn_output parametrizes d over {64, 96, 128, 192, 256} and never exercises head_dim < 64, even though _validate_head_dims accepts head_dim >= 8 for sm_100/110. That coverage gap let the SMEM-overflow bug in #2591 slip through. This focused test covers d in {8, 16, 32} x causal x seqlen in {128, 2048}. The seqlen=2048 cases push q_stage 1->2 (the actual bug trigger); the seqlen=128 cases also exercise the q_stage=1 boundary that fits on main today but is structurally adjacent. d=32 serves as a canary against any future tighter kv_stage clamp regressing it. --- tests/cute/test_flash_attn.py | 44 +++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 764d7123681..bf881efe1c0 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -438,6 +438,50 @@ def test_flash_attn_output( ).abs().max().item() + dv_atol +# Regression test for #2591: SMEM overflow at small head_dims on SM100. The main +# test_flash_attn_output skips d < 64, but _validate_head_dims accepts head_dim >= 8 +# for sm_100/110, so this path needs coverage. Trigger requires +# seqlen_q_packgqa > tile_m to push q_stage 1->2. 
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("d", [8, 16, 32])
+@pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128), (2048, 2048)])
+@retry_on_oom
+@maybe_fake_tensor_mode(USE_FAKE_TENSOR)
+def test_flash_attn_small_head_dim(seqlen_q, seqlen_k, d, causal, dtype):  # forward-only: flash_attn_func vs. attention_ref
+    device = "cuda"
+    seed = 0  # fixed seed, applied to both Python's and torch's RNG below
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+    batch_size = 2
+    nheads = 2
+    nheads_kv = nheads  # K/V use the same number of heads as Q
+    dtype_ref = dtype  # reference inputs are created directly in the test dtype
+    q_ref = torch.randn(
+        batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref
+    ).requires_grad_()
+    k_ref = torch.randn(
+        batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
+    ).requires_grad_()
+    v_ref = torch.randn(
+        batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
+    ).requires_grad_()
+    q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)]  # detached copies fed to the kernel
+    out_ref, _ = attention_ref(q_ref, k_ref, v_ref, None, None, causal=causal)  # reference with attention_ref's defaults
+    out_pt, _ = attention_ref(  # second reference run; its gap to out_ref scales the tolerance below
+        q_ref, k_ref, v_ref, None, None, causal=causal, upcast=False, reorder_ops=True
+    )
+    out, _ = flash_attn_func(q, k, v, causal=causal)  # output under test
+    if is_fake_mode():  # no numeric comparison when is_fake_mode() reports true
+        return
+    fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()  # absolute slack: 2x the residual of out_ref + 0.3 - 0.3 - out_ref
+    assert (out - out_ref).abs().max().item() <= 2 * (  # kernel error <= 2x the out_pt-vs-out_ref error, plus fwd_atol
+        out_pt - out_ref
+    ).abs().max().item() + fwd_atol
+
+
 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])