Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,7 +1162,7 @@ def load(

n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)

if n_block_min < n_block_max:
if const_expr(not self.is_split_kv) or n_block_min < n_block_max:
Copy link
Copy Markdown

@Edenzzzz Edenzzzz Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi can I ask why this is a regression? e.g. for the 1st query tile, this will load most of the subsequent kv tiles (instead of skipping) only to be masked out later? Thanks

load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0
page_idx = (
mPageTable[batch_idx, n_block_max - 1]
Expand Down Expand Up @@ -1255,7 +1255,7 @@ def mma(
seqlen = SeqlenInfoCls(batch_idx)
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)

if n_block_min < n_block_max:
if const_expr(not self.is_split_kv) or n_block_min < n_block_max:
for stage in cutlass.range_constexpr(self.q_stage):
# GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1)
# 1. wait for Q0 / Q1
Expand Down Expand Up @@ -1493,7 +1493,7 @@ def softmax_loop(
seqlen = SeqlenInfoCls(batch_idx)
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)

if n_block_min < n_block_max:
if const_expr(not self.is_split_kv) or n_block_min < n_block_max:
mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k)
mask_fn = partial(
mask.apply_mask_sm100,
Expand Down Expand Up @@ -1807,7 +1807,7 @@ def correction_loop(
# Default LSE to -inf for invalid split_idx tiles
stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage

if n_block_min < n_block_max:
if const_expr(not self.is_split_kv) or n_block_min < n_block_max:
# Ignore first signal from softmax as no correction is required
cute.arch.mbarrier_wait(
mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase
Expand Down Expand Up @@ -2132,7 +2132,7 @@ def epilogue_s2g(
seqlen = SeqlenInfoCls(batch_idx)
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)

if n_block_min < n_block_max:
if const_expr(not self.is_split_kv) or n_block_min < n_block_max:
if const_expr(self.is_split_kv):
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
else:
Expand Down
Loading