diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 6e030b17615..c4a569fa0d1 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1162,7 +1162,7 @@ def load( n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - if n_block_min < n_block_max: + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 page_idx = ( mPageTable[batch_idx, n_block_max - 1] @@ -1255,7 +1255,7 @@ def mma( seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - if n_block_min < n_block_max: + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: for stage in cutlass.range_constexpr(self.q_stage): # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) # 1. wait for Q0 / Q1 @@ -1493,7 +1493,7 @@ def softmax_loop( seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - if n_block_min < n_block_max: + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) mask_fn = partial( mask.apply_mask_sm100, @@ -1807,7 +1807,7 @@ def correction_loop( # Default LSE to -inf for invalid split_idx tiles stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage - if n_block_min < n_block_max: + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: # Ignore first signal from softmax as no correction is required cute.arch.mbarrier_wait( mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase @@ -2132,7 +2132,7 @@ def epilogue_s2g( seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - if n_block_min < n_block_max: + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: if const_expr(self.is_split_kv): mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] else: