diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 76c856221c5..94f0c88d817 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -63,7 +63,7 @@ def __init__( self.num_threads = num_threads self.AtomLayoutMdQ = AtomLayoutMdQ self.dQ_swapAB = dQ_swapAB - self.use_2cta_instrs = use_2cta_instrs and arch // 10 == 10 and head_dim != 64 + self.use_2cta_instrs = use_2cta_instrs and arch // 10 in [10, 11] and head_dim != 64 self.cluster_size = cluster_size @staticmethod @@ -373,7 +373,7 @@ def kernel( seqlen_q = seqlen.seqlen_q seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) - if const_expr(self.arch // 10 == 10 and self.use_2cta_instrs): + if const_expr(self.arch // 10 in [10, 11] and self.use_2cta_instrs): # 2-CTA: remap dQaccum layout into TMEM view before writing sdQ num_reduce_threads = self.num_threads thr_mma_dsk = tiled_mma.get_slice(tidx)