From 1094592acaa957c2687e1fd83b9fabd701870291 Mon Sep 17 00:00:00 2001 From: johnsonms Date: Mon, 25 May 2026 06:51:15 +0000 Subject: [PATCH 1/2] Use is_family_of for sm_90 and sm_103 arch checks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Follow-up to #2572 — apply the same is_family_of pattern to the two remaining range-style arch checks for consistency: - flash_fwd_sm90.py:69 (SM 9.x assert) - flash_fwd_sm100.py:195 (is_sm103 flag) Same semantic narrowing as #2572: bare-base SMs (sm_90, sm_103) are excluded. These kernels rely on wgmma / UMMA / 2CTA paths that require the a/f PTX variant anyway, so bare-base targets could not compile. --- flash_attn/cute/flash_fwd_sm100.py | 2 +- flash_attn/cute/flash_fwd_sm90.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 576238bcafb..f452a89c969 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -192,7 +192,7 @@ def __init__( self.mask_vec_size: cutlass.Constexpr = getattr(mask_mod, "__vec_size__", 1) # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) - is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f + is_sm103 = self.arch.is_family_of(Arch.sm_103f) self.is_sm103 = is_sm103 # enable_ex2_emu is derived: True if tuning config has freq > 0, else fallback to default logic _default_enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103 diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py index 3d57d6718fc..93bccfa715b 100644 --- a/flash_attn/cute/flash_fwd_sm90.py +++ b/flash_attn/cute/flash_fwd_sm90.py @@ -66,7 +66,7 @@ def __init__( "Paged KV does not support irregular head dim" ) self.cluster_shape_mn = (1, 1) - assert self.arch >= Arch.sm_90 and self.arch <= Arch.sm_90a, "Only SM 9.x is supported" + assert self.arch.is_family_of(Arch.sm_90a), "Only SM 9.x is supported" def _get_smem_layout_atom(self): sQ_layout_atom = warpgroup.make_smem_layout_atom( From 5833281f8e0805763f30bac540b44f46d046ffc9 Mon Sep 17 00:00:00 2001 From: johnsonms Date: Mon, 25 May 2026 08:30:17 +0000 Subject: [PATCH 2/2] Clarify is_sm103 forward-inclusive semantics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit is_family_of(sm_103f) also matches any future sm_10x with x > 3, not just sm_103a/f. This was raised in PR review (@ocss884) — adding an inline comment clarifying that this forward-inclusive behavior is intentional: the flag gates ex2 emulation, sm_103 (B300) has fast hardware ex2, and later Blackwell variants in the same family are assumed to inherit it. No code-behavior change. --- flash_attn/cute/flash_fwd_sm100.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index f452a89c969..57755d12cb9 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -192,6 +192,10 @@ def __init__( self.mask_vec_size: cutlass.Constexpr = getattr(mask_mod, "__vec_size__", 1) # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) + # NOTE: is_family_of also matches any future sm_10x with x > 3 — intentional. + # The flag gates ex2 emulation; sm_103 (B300) has fast hardware ex2 and later + # Blackwell variants are assumed to inherit this, so forward-inclusion is correct + # despite the literal `is_sm103` name. is_sm103 = self.arch.is_family_of(Arch.sm_103f) self.is_sm103 = is_sm103 # enable_ex2_emu is derived: True if tuning config has freq > 0, else fallback to default logic