From ecf2a262ab9dcadf83698386286004d7bc7562b6 Mon Sep 17 00:00:00 2001 From: johnsonms Date: Mon, 25 May 2026 07:27:43 +0000 Subject: [PATCH 1/3] Fix bwd postprocess 2CTA gating to include sm_11x The 2CTA gating in flash_bwd_postprocess.py used `arch // 10 == 10`, which only matches SM 10.x (B100/B200/B300) and misses SM 11.x (Thor). The rest of the codebase (e.g. interface.py:549, 563, 834) consistently gates Blackwell-family 2CTA features as `arch // 10 in [10, 11]`. Bring the two postprocess sites in line with that convention. Flagged by @jayhshah in #2572 follow-up discussion. --- flash_attn/cute/flash_bwd_postprocess.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 76c856221c5..94f0c88d817 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -63,7 +63,7 @@ def __init__( self.num_threads = num_threads self.AtomLayoutMdQ = AtomLayoutMdQ self.dQ_swapAB = dQ_swapAB - self.use_2cta_instrs = use_2cta_instrs and arch // 10 == 10 and head_dim != 64 + self.use_2cta_instrs = use_2cta_instrs and arch // 10 in [10, 11] and head_dim != 64 self.cluster_size = cluster_size @staticmethod @@ -373,7 +373,7 @@ def kernel( seqlen_q = seqlen.seqlen_q seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) - if const_expr(self.arch // 10 == 10 and self.use_2cta_instrs): + if const_expr(self.arch // 10 in [10, 11] and self.use_2cta_instrs): # 2-CTA: remap dQaccum layout into TMEM view before writing sdQ num_reduce_threads = self.num_threads thr_mma_dsk = tiled_mma.get_slice(tidx) From 5592f240f1fa2e28a2271e2ba1ab232cdbeb4e27 Mon Sep 17 00:00:00 2001 From: johnsonms Date: Mon, 25 May 2026 07:28:24 +0000 Subject: [PATCH 2/3] Include sm_110 in interface.py Blackwell-family heuristics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three sites in interface.py gate Blackwell-family behavior using `arch // 10 == 10`, which appears inconsistent with the rest of the file's `arch // 10 in [10, 11]` convention (used at lines 549, 563, 834, 974, 1035, etc.): - L533: `q_stage` heuristic for Blackwell forward - L579: `use_dedicated_hd256_kernel` (forward) - L1335: `use_dedicated_hd256_kernel` (backward) The dispatch in `_flash_attn_fwd` already routes both sm_10x and sm_11x through the same `FlashAttentionForwardSm100` / MLA classes, so these gates likely should treat them the same. NOTE FOR REVIEWERS: I'm not certain these are all oversight vs. intentional SM100-only paths. If any of them is intentional, please flag so I can revert just that hunk. The FP8 assert at L480 is left untouched on purpose — its error message reads as deliberate. --- flash_attn/cute/interface.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 189ae1faca7..5e8674bf1ad 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -530,7 +530,7 @@ def _flash_attn_fwd( if cu_seqlens_k is None and seqused_k is None: min_seqlen_k = seqlen_k seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead - if arch // 10 == 10: + if arch // 10 in [10, 11]: q_stage = 2 if seqlen_q_packgqa > tile_m else 1 else: q_stage = 1 @@ -575,8 +575,8 @@ def _flash_attn_fwd( and (tile_m % qhead_per_kvhead == 0 or not pack_gqa) ) - # hd=256 2CTA forward uses dedicated kernel (SM100 only) - use_dedicated_hd256_kernel = arch // 10 == 10 and head_dim == 256 and head_dim_v == 256 + # hd=256 2CTA forward uses dedicated kernel (Blackwell family) + use_dedicated_hd256_kernel = arch // 10 in [10, 11] and head_dim == 256 and head_dim_v == 256 use_2cta_instrs = use_2cta_instrs or use_dedicated_hd256_kernel if softcap is not None: @@ -1332,7 +1332,7 @@ def _flash_attn_bwd( cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1 use_2cta_instrs = cluster_size==2 - use_dedicated_hd256_kernel = arch // 10 == 10 and head_dim == 256 and head_dim_v == 256 + use_dedicated_hd256_kernel = arch // 10 in [10, 11] and head_dim == 256 and head_dim_v == 256 use_2cta_instrs = use_2cta_instrs or use_dedicated_hd256_kernel q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ From 77efc898e496ee9a7cdd51cadc1d0f27551fc8c0 Mon Sep 17 00:00:00 2001 From: johnsonms Date: Mon, 25 May 2026 23:35:02 +0000 Subject: [PATCH 3/3] Apply ruff format to flash_bwd_sm100.py Pre-existing format drift surfaced by pre-commit. Not in the cute_exclude pattern, so it gets auto-fixed when other files in flash_attn/cute/ are touched in the same commit chain. --- flash_attn/cute/flash_bwd_sm100.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 81462e50afd..061ede3d983 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -1423,9 +1423,7 @@ def kernel( ) TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) - AttentionMaskCls = self._generate_attention_mask_cls( - window_size_left, window_size_right - ) + AttentionMaskCls = self._generate_attention_mask_cls(window_size_left, window_size_right) # EMPTY # (15) if warp_idx == self.empty_warp_id: