Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 65 additions & 10 deletions flash_attn/cute/block_sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,55 @@
from flash_attn.cute.named_barrier import NamedBarrierBwd


# NOTE [SM100 block-sparse empty tiles: mbarrier contract]
#
# For block-sparse SM100 forward, a given (m_block, stage) Q tile can have zero active
# KV blocks (total_block_cnt == 0). In that case there is no seqlen_kv iteration, so
# the softmax warp-group has no row stats to publish.
#
# The correction warp-group seeds fully-masked-row stats and runs the usual correction
# epilogue so output/LSE have well-defined values. Both warp-groups must still perform
# the softmax<->correction mbarrier handshake so phases advance correctly across
# empty->empty and empty->non-empty tile sequences.
#
# In the no-sink case, this corresponds to the usual fully-masked-row convention:
# output is zero and LSE is -inf.
#
# Barrier contract (each is `mbar_ptr + <offset> + stage`):
#
# Producer/consumer pairs:
# - `mbar_softmax_corr_full` : softmax arrive -> correction wait
# - `mbar_softmax_corr_empty` : correction arrive -> softmax wait
# - `mbar_P_full_O_rescaled` : softmax arrive (+ correction arrive) -> MMA wait
# - `mbar_P_full_2` : softmax arrive -> MMA wait
# - `mbar_corr_epi_full_/empty` : correction <-> epilogue (only when epilogue is separate)
#
# Empty tile (`total_block_cnt == 0`):
# - Softmax: skips the seqlen_kv softmax path entirely (no P stores, no `mbar_P_full_*`).
# It only arrives `mbar_softmax_corr_full` once per stage as a synthetic "no work" signal.
# At the `softmax_loop` level, softmax unconditionally waits `mbar_softmax_corr_empty`
# before each tile (when block-sparse) to drain a prior correction arrival and keep
# phases aligned across non-empty -> empty transitions.
# - Correction: waits `mbar_softmax_corr_full`, seeds stats + runs `correction_epilogue(scale=0)`,
# and arrives `mbar_softmax_corr_empty` (and `mbar_corr_epi_full_/empty` when applicable).
# - No `mbar_P_full_*` barriers are arrived (no P, no MMA O); only the softmax<->correction
# (and correction<->epilogue) handshakes advance phases.
#
# Non-empty tile:
# - Softmax: runs `softmax_step` (produces P) and uses `mbar_softmax_corr_full/empty` to
# publish row_max (during seqlen_kv) and final row stats (once per tile), and to advance phases;
# arrives `mbar_P_full_*` when P is stored.
# - Correction: waits `mbar_softmax_corr_full`, may rescale/release O, arrives `mbar_softmax_corr_empty`
# to ack/advance, and arrives `mbar_P_full_O_rescaled` when MMA can proceed.
#
# Backward (SM100):
# - Empty KV tile: for a given `n_block`, `total_m_block_cnt == 0` means no Q tiles contribute.
# - Both the load and compute loops guard all pipeline work on `process_tile`, so empty tiles
# skip producer/consumer operations entirely (no per-tile mbarrier phase handshake like forward).
# - In the `not dKV_postprocess` path, dK/dV for empty KV tiles are explicitly written as zeros
# even when `process_tile == False` (see `flash_bwd_sm100.py` `should_zero_dKV`).


@cute.jit
def load_block_list(
block_indices: cute.Tensor,
Expand Down Expand Up @@ -671,10 +720,20 @@ def handle_block_sparse_empty_tile_correction_sm100(
gO: Optional[cute.Tensor] = None,
gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
):
"""Handle the block-sparse case where a tile is fully masked:
* zero staged results
* seed stats
* satisfy the usual barrier protocol so downstream warps continue to make progress.
"""Handle SM100 forward block-sparse tiles with no active KV blocks.

This path is taken when `total_block_cnt == 0`. The softmax warp-group still
arrives `mbar_softmax_corr_full` (synthetic "no work") so the correction
warp-group can:

- seed fully-masked-row stats (row_sum=1; row_max=-inf when tracked) for LSE
- run `correction_epilogue` with `scale=0` so the output tile is written as zeros
(independent of any prior tmem contents)
- wait on `mbar_softmax_corr_full` and arrive `mbar_softmax_corr_empty`
(and `mbar_corr_epi_*` when applicable) so phases stay aligned across tiles

This helper intentionally does not touch `mbar_P_full_*` since no P is produced.
See NOTE [SM100 block-sparse empty tiles: mbarrier contract].
"""
LOG2_E = Float32(math.log2(math.e))

Expand Down Expand Up @@ -708,6 +767,7 @@ def handle_block_sparse_empty_tile_correction_sm100(
acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value
stats[stage] = (row_sum_value, row_max_value, acc_flag)

# See NOTE [SM100 block-sparse empty tiles: mbarrier contract].
cute.arch.mbarrier_wait(
mbar_ptr + mbar_softmax_corr_full_offset + stage,
softmax_corr_consumer_phase,
Expand All @@ -734,11 +794,8 @@ def handle_block_sparse_empty_tile_correction_sm100(
)
if const_expr(gmem_tiled_copy_O is None):
cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage)
cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage)
cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage)

softmax_corr_consumer_phase ^= 1
o_corr_consumer_phase ^= 1
corr_epi_producer_phase ^= 1

return (
Expand Down Expand Up @@ -788,10 +845,8 @@ def softmax_block_sparse_sm100(
total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt

if total_block_cnt == 0:
# See NOTE [SM100 block-sparse empty tiles: mbarrier contract].
cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_full_offset + stage_idx)
cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage_idx)
cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage_idx)
cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage_idx)
else:
if curr_mask_block_cnt > 0:
mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
Expand Down
5 changes: 3 additions & 2 deletions flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -1734,8 +1734,8 @@ def softmax_loop(
head_divmod=head_divmod,
)

if has_work:
# Softmax acts as the producer: wait until correction signals the stage is empty
if const_expr(self.use_block_sparsity) or has_work:
# See block_sparse_utils.py NOTE [SM100 block-sparse empty tiles: mbarrier contract].
cute.arch.mbarrier_wait(
mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase
)
Expand Down Expand Up @@ -1785,6 +1785,7 @@ def softmax_loop(
] = softmax.row_max[0]
# if tidx == 0:
# cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0])
# See block_sparse_utils.py NOTE [SM100 block-sparse empty tiles: mbarrier contract].
cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage)
# if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0])
else:
Expand Down
56 changes: 56 additions & 0 deletions tests/cute/test_mask_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -1415,5 +1415,61 @@ def causal_mask(b, h, q_idx, kv_idx):
assert cute_dv_err <= bwd_rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}"


@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="SM100/SM110 persistent forward only")
def test_persistent_blocksparse_empty_tiles():
"""Regression test for persistent forward deadlock with highly-sparse block masks.

When most Q-tiles are empty (no active KV blocks), the persistent kernel
deadlocked due to barrier phase desync in the empty-tile paths of both the
softmax and correction warp groups.
"""
torch.manual_seed(5)
batch_size, nheads_q, nheads_kv = 2, 16, 1
seqlen_q, seqlen_k, headdim = 8192, 128, 128
tile_m, tile_n = 128, 128
dtype = torch.bfloat16

sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m
window_size = 64
mask_mod_cute, mask_mod_flex = get_mask_pair(
"sliding_window", seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size,
)

bm = create_block_mask(
mask_mod_flex, batch_size, nheads_q, seqlen_q, seqlen_k,
device="cuda", BLOCK_SIZE=(sparse_tile_m, tile_n),
)
(_, _, kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, *_) = bm.as_tuple()
block_sparse_mask_fwd = BlockSparseTensorsTorch(
mask_block_cnt=kv_mask_cnt, mask_block_idx=kv_mask_idx,
full_block_cnt=full_kv_cnt, full_block_idx=full_kv_idx,
block_size=(sparse_tile_m, tile_n),
)

q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, device="cuda", dtype=dtype)
k = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device="cuda", dtype=dtype)
v = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device="cuda", dtype=dtype)

out, lse = _flash_attn_fwd(
q=q, k=k, v=v,
out=torch.empty(batch_size, seqlen_q, nheads_q, headdim, device="cuda", dtype=dtype),
lse=torch.empty(batch_size, nheads_q, seqlen_q, device="cuda", dtype=torch.float32),
cu_seqlens_q=None, cu_seqlens_k=None, seqused_q=None, seqused_k=None,
page_table=None, softmax_scale=1.0 / math.sqrt(headdim),
causal=False, softcap=None,
window_size_left=None, window_size_right=None,
learnable_sink=None,
m_block_size=tile_m, n_block_size=tile_n,
pack_gqa=False, _compute_capability=None,
score_mod=None, mask_mod=mask_mod_cute,
block_sparse_tensors=block_sparse_mask_fwd,
return_lse=True, aux_tensors=None,
)
torch.cuda.synchronize()
assert out.shape == (batch_size, seqlen_q, nheads_q, headdim)
assert not out.isnan().any()



if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])