From ec1d48978b932090fa49e8b0102f4c10f250f378 Mon Sep 17 00:00:00 2001 From: Ping Gong Date: Fri, 24 Apr 2026 11:26:10 +0000 Subject: [PATCH] [CuTe,Sm110] Fix sm110 2cta dQ postprocess --- flash_attn/cute/flash_bwd_postprocess.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 76c856221c5..94f0c88d817 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -63,7 +63,7 @@ def __init__( self.num_threads = num_threads self.AtomLayoutMdQ = AtomLayoutMdQ self.dQ_swapAB = dQ_swapAB - self.use_2cta_instrs = use_2cta_instrs and arch // 10 == 10 and head_dim != 64 + self.use_2cta_instrs = use_2cta_instrs and arch // 10 in [10, 11] and head_dim != 64 self.cluster_size = cluster_size @staticmethod @@ -373,7 +373,7 @@ def kernel( seqlen_q = seqlen.seqlen_q seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) - if const_expr(self.arch // 10 == 10 and self.use_2cta_instrs): + if const_expr(self.arch // 10 in [10, 11] and self.use_2cta_instrs): # 2-CTA: remap dQaccum layout into TMEM view before writing sdQ num_reduce_threads = self.num_threads thr_mma_dsk = tiled_mma.get_slice(tidx)