251 changes: 251 additions & 0 deletions flash_attn/cute/block_sparse_utils.py
@@ -814,3 +814,254 @@ def softmax_block_sparse_sm100(
s0_s1_sequence_phase,
total_block_cnt == 0,
)


# =============================================================================
# Backward-specific block-sparse helpers (SM100)
# =============================================================================
#
# In backward, iteration is transposed compared to forward:
# - Forward: outer loop over m_blocks (Q tiles), inner loop over n_blocks (KV tiles)
# - Backward: outer loop over n_blocks (KV tiles), inner loop over m_blocks (Q tiles)
#
# The backward block-sparse tensors use "Q direction" indexing:
# - q_block_cnt[batch, head, n_block] → count of m_blocks to process for this KV tile
# - q_block_idx[batch, head, n_block, :] → indices of m_blocks to process
#
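#
# Illustrative reference (an assumption for exposition, not used by the kernels
# below): a plain-Python sketch of how the Q-direction metadata could be built
# on the host from a dense boolean block mask, for a single (batch, head).
#
def _reference_q_direction_metadata(block_mask):
    """block_mask[m][n] is True if the (m_block, n_block) tile must be computed.

    Returns (q_block_cnt, q_block_idx) per n_block, indices in ascending order;
    batch and head dimensions are omitted for brevity.
    """
    num_m = len(block_mask)
    num_n = len(block_mask[0]) if num_m > 0 else 0
    q_block_cnt = [0] * num_n
    q_block_idx = [[0] * num_m for _ in range(num_n)]
    for n in range(num_n):
        count = 0
        for m in range(num_m):
            if block_mask[m][n]:
                q_block_idx[n][count] = m  # ascending m_block order
                count += 1
        q_block_cnt[n] = count
    return q_block_cnt, q_block_idx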


@cute.jit
def get_total_q_block_count_bwd(
blocksparse_tensors: BlockSparseTensors,
batch_idx,
head_idx,
n_block,
subtile_factor: cutlass.Constexpr = 1,
m_block_max: int = 0,
):
"""Count total tile iterations for given n_block (KV tile) in backward.

Args:
m_block_max: Maximum m_block index from causal/local masking constraints.
Computed by block_info.get_m_block_min_max() based on sequence lengths
and attention mask type. When > 0, caps the result to ensure we don't
count sparse blocks that fall outside the valid causal/local window.

Returns min(sparse_block_count * subtile_factor, m_block_max) when m_block_max > 0;
otherwise returns sparse_block_count * subtile_factor.
"""
q_block_cnt, _, full_q_block_cnt, _ = blocksparse_tensors
total = q_block_cnt[batch_idx, head_idx, n_block]
if const_expr(full_q_block_cnt is not None):
total = total + full_q_block_cnt[batch_idx, head_idx, n_block]
result = total * subtile_factor
if m_block_max > 0:
result = cutlass.min(result, m_block_max)
return result
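
# Worked example (illustrative numbers, not taken from the PR): with
# q_block_cnt == 3, full_q_block_cnt == 1, subtile_factor == 2 and
# m_block_max == 5, get_total_q_block_count_bwd returns min((3 + 1) * 2, 5) == 5,
# so subtiles beyond the causal/local window at the high end are never counted.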


@cute.jit
def produce_block_sparse_q_loads_bwd_sm100(
blocksparse_tensors: BlockSparseTensors,
batch_idx,
head_idx,
n_block,
# Pipeline states (will be returned after advancing)
producer_state_Q_LSE,
producer_state_dO_dPsum,
# Pipelines
pipeline_Q,
pipeline_LSE,
pipeline_dO,
pipeline_dPsum,
# Load functions
load_K,
load_V,
load_Q,
load_dO,
copy_stats,
# Global tensors for LSE/dPsum
gLSE,
sLSE,
gdPsum,
sdPsum,
# TMA copy bytes for extra_tx_count
tma_copy_bytes_K,
tma_copy_bytes_V,
# Flags for which loads to perform
should_load_Q: cutlass.Constexpr,
should_load_dO: cutlass.Constexpr,
# Subtiling factor and bounds
subtile_factor: cutlass.Constexpr = 1,
m_block_max: int = 0,
):
"""SM100 backward block sparse loading with subtiling.

Returns the updated (producer_state_Q_LSE, producer_state_dO_dPsum) pipeline states.
The first iteration also loads K/V alongside Q/dO; subsequent iterations load only Q/dO.
"""
(
curr_q_cnt,
curr_q_idx,
curr_full_cnt,
curr_full_idx,
loop_count,
) = get_block_sparse_iteration_info_bwd(
blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor, m_block_max
)

for iter_idx in cutlass.range(loop_count, unroll=1):
m_block, _ = get_m_block_from_iter_bwd(
iter_idx,
curr_q_cnt,
curr_q_idx,
curr_full_cnt,
curr_full_idx,
subtile_factor,
)

if iter_idx == 0:
# First block: load K/V alongside Q/dO
if const_expr(should_load_Q):
pipeline_Q.producer_acquire(producer_state_Q_LSE, extra_tx_count=tma_copy_bytes_K)
load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE))
load_Q(m_block, producer_state=producer_state_Q_LSE)
pipeline_Q.producer_commit(producer_state_Q_LSE)
pipeline_LSE.producer_acquire(producer_state_Q_LSE)
with cute.arch.elect_one():
copy_stats(
gLSE[None, m_block],
sLSE[None, producer_state_Q_LSE.index],
mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
)
producer_state_Q_LSE.advance()
if const_expr(should_load_dO):
pipeline_dO.producer_acquire(
producer_state_dO_dPsum, extra_tx_count=tma_copy_bytes_V
)
load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum))
load_dO(m_block, producer_state=producer_state_dO_dPsum)
pipeline_dO.producer_commit(producer_state_dO_dPsum)
pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
with cute.arch.elect_one():
copy_stats(
gdPsum[None, m_block],
sdPsum[None, producer_state_dO_dPsum.index],
mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
)
producer_state_dO_dPsum.advance()
else:
# Subsequent blocks: just load Q/dO (K/V already loaded)
if const_expr(should_load_Q):
pipeline_Q.producer_acquire(producer_state_Q_LSE)
load_Q(m_block, producer_state=producer_state_Q_LSE)
pipeline_Q.producer_commit(producer_state_Q_LSE)
pipeline_LSE.producer_acquire(producer_state_Q_LSE)
with cute.arch.elect_one():
copy_stats(
gLSE[None, m_block],
sLSE[None, producer_state_Q_LSE.index],
mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
)
producer_state_Q_LSE.advance()
if const_expr(should_load_dO):
pipeline_dO.producer_acquire(producer_state_dO_dPsum)
load_dO(m_block, producer_state=producer_state_dO_dPsum)
pipeline_dO.producer_commit(producer_state_dO_dPsum)
pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
with cute.arch.elect_one():
copy_stats(
gdPsum[None, m_block],
sdPsum[None, producer_state_dO_dPsum.index],
mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
)
producer_state_dO_dPsum.advance()

return producer_state_Q_LSE, producer_state_dO_dPsum
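
# Hypothetical caller sketch (illustration only; the surrounding loop and variable
# names are assumptions, not taken from the actual kernel): the producer warp would
# invoke this once per (batch, head, n_block) work item, threading the pipeline
# states through, roughly as
#
#   producer_state_Q_LSE, producer_state_dO_dPsum = produce_block_sparse_q_loads_bwd_sm100(
#       blocksparse_tensors, batch_idx, head_idx, n_block,
#       producer_state_Q_LSE, producer_state_dO_dPsum,
#       pipeline_Q, pipeline_LSE, pipeline_dO, pipeline_dPsum,
#       load_K, load_V, load_Q, load_dO, copy_stats,
#       gLSE, sLSE, gdPsum, sdPsum,
#       tma_copy_bytes_K, tma_copy_bytes_V,
#       should_load_Q=True, should_load_dO=True,
#       subtile_factor=subtile_factor, m_block_max=m_block_max,
#   )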


@cute.jit
def get_block_sparse_iteration_info_bwd(
blocksparse_tensors: BlockSparseTensors,
batch_idx,
head_idx,
n_block,
subtile_factor: cutlass.Constexpr = 1,
m_block_max: int = 0,
):
"""Extract block-sparse iteration info for backward pass.

Args:
m_block_max: Maximum m_block index from causal/local masking constraints.
Computed by block_info.get_m_block_min_max() based on sequence lengths
and attention mask type. When > 0, caps total_count to ensure we don't
process sparse blocks that fall outside the valid causal/local window.
This combines block sparsity with causal/local masking.

Returns (curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count).
"""
q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

if const_expr(full_cnt is not None):
curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
else:
curr_full_cnt = Int32(0)
curr_full_idx = None

sparse_block_count = curr_q_cnt
if const_expr(full_cnt is not None):
sparse_block_count = sparse_block_count + curr_full_cnt

total_count = sparse_block_count * subtile_factor
if m_block_max > 0:
total_count = cutlass.min(total_count, m_block_max)

return curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count
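
# Worked example (illustrative numbers): with q_cnt == 2, full_cnt == 1,
# subtile_factor == 2 and m_block_max == 0 (no cap), this returns
# (2, <row of q_idx>, 1, <row of full_idx>, 6): six loop iterations,
# i.e. two subtiles for each of the three sparse blocks.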


@cute.jit
def get_m_block_from_iter_bwd(
iter_idx,
curr_q_cnt,
curr_q_idx: cute.Tensor,
curr_full_cnt,
curr_full_idx: Optional[cute.Tensor],
subtile_factor: cutlass.Constexpr = 1,
):
"""Derive m_block index and is_full_block flag from iteration index.

In backward, we iterate in FORWARD order: masked blocks first (low to high),
then full blocks (low to high). This ensures that when loop_count is capped
to m_block_max, we skip the high (potentially out-of-bounds) m_blocks at the
end of iteration rather than in the middle.

With subtiling (subtile_factor > 1):
- sparse_iter_idx = iter_idx // subtile_factor (which sparse block)
- subtile_offset = iter_idx % subtile_factor (which subtile within sparse block)
- m_block = sparse_m_block * subtile_factor + subtile_offset

Returns (m_block, is_full_block):
- m_block: The actual Q-tile block index (after subtiling)
- is_full_block: True if this is a full block (no mask_mod needed)
Note: All subtiles within a sparse block share the same is_full_block status
"""
sparse_iter_idx = iter_idx // subtile_factor
subtile_offset = iter_idx % subtile_factor

sparse_m_block = Int32(0)
is_full_block = False

# Forward order: process low sparse block indices first
if sparse_iter_idx < curr_q_cnt:
sparse_m_block = curr_q_idx[sparse_iter_idx]
is_full_block = False
else:
full_iter = sparse_iter_idx - curr_q_cnt
sparse_m_block = curr_full_idx[full_iter]
is_full_block = True

m_block = sparse_m_block * subtile_factor + subtile_offset

return m_block, is_full_block
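
# Worked example (illustrative numbers): with subtile_factor == 2, curr_q_cnt == 2,
# curr_q_idx == [1, 4, ...] and curr_full_idx == [5, ...]:
#   iter_idx == 3 -> sparse_iter_idx == 1, subtile_offset == 1
#                 -> masked block, sparse_m_block == 4, m_block == 4 * 2 + 1 == 9
#   iter_idx == 4 -> sparse_iter_idx == 2, subtile_offset == 0
#                 -> full block,   sparse_m_block == 5, m_block == 5 * 2 + 0 == 10
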
23 changes: 23 additions & 0 deletions flash_attn/cute/block_sparsity.py
@@ -102,6 +102,29 @@ def get_block_sparse_expected_shapes(
return expected_count_shape, expected_index_shape


def get_block_sparse_expected_shapes_bwd(
batch_size: int,
num_head: int,
seqlen_q: int,
seqlen_k: int,
m_block_size: int,
n_block_size: int,
subtile_factor: int,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:
"""Return (expected_count_shape, expected_index_shape) for backward block sparse normalization.

Backward uses Q-direction indexing (transposed from forward): shapes are indexed by
N-blocks (KV tiles) first, then M-blocks (Q tiles). The sparse Q block size
(sparse_block_size_q) is subtile_factor * m_block_size.
"""
sparse_block_size_q = subtile_factor * m_block_size
expected_m_blocks = ceildiv(seqlen_q, sparse_block_size_q)
expected_n_blocks = ceildiv(seqlen_k, n_block_size)
expected_count_shape = (batch_size, num_head, expected_n_blocks)
expected_index_shape = (batch_size, num_head, expected_n_blocks, expected_m_blocks)
return expected_count_shape, expected_index_shape
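
# Worked example (illustrative numbers): batch_size=2, num_head=8, seqlen_q=1024,
# seqlen_k=2048, m_block_size=64, n_block_size=128, subtile_factor=2 gives
# sparse_block_size_q = 128, expected_m_blocks = ceildiv(1024, 128) = 8 and
# expected_n_blocks = ceildiv(2048, 128) = 16, i.e. count shape (2, 8, 16) and
# index shape (2, 8, 16, 8).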


def normalize_block_sparse_tensors(
tensors: BlockSparseTensorsTorch,
*,