diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 576238bcafb..57755d12cb9 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -192,7 +192,11 @@ def __init__( self.mask_vec_size: cutlass.Constexpr = getattr(mask_mod, "__vec_size__", 1) # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) - is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f + # NOTE: is_family_of also matches any future sm_10x with x > 3 — intentional. + # The flag gates ex2 emulation; sm_103 (B300) has fast hardware ex2 and later + # Blackwell variants are assumed to inherit this, so forward-inclusion is correct + # despite the literal `is_sm103` name. + is_sm103 = self.arch.is_family_of(Arch.sm_103f) self.is_sm103 = is_sm103 # enable_ex2_emu is derived: True if tuning config has freq > 0, else fallback to default logic _default_enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103 diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py index 3d57d6718fc..93bccfa715b 100644 --- a/flash_attn/cute/flash_fwd_sm90.py +++ b/flash_attn/cute/flash_fwd_sm90.py @@ -66,7 +66,7 @@ def __init__( "Paged KV does not support irregular head dim" ) self.cluster_shape_mn = (1, 1) - assert self.arch >= Arch.sm_90 and self.arch <= Arch.sm_90a, "Only SM 9.x is supported" + assert self.arch.is_family_of(Arch.sm_90a), "Only SM 9.x is supported" def _get_smem_layout_atom(self): sQ_layout_atom = warpgroup.make_smem_layout_atom(