Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions flash_attn/cute/block_sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,17 @@ def finish_overlap_v_load(
return kv_producer_state


@cute.jit
def sparse_tensor_m_block(
    m_block,
    qhead_per_kvhead: cutlass.Constexpr[int],
):
    """Map packed m_block indices to block-sparse tensor indices.

    With Pack-GQA (qhead_per_kvhead > 1) the incoming m_block is in packed
    space; dividing by the pack factor recovers the index used by the
    block-sparse count/index tensors. Identity mapping when not packing.
    """
    if const_expr(qhead_per_kvhead == 1):
        return m_block
    return m_block // qhead_per_kvhead


@cute.jit
def produce_block_sparse_loads(
blocksparse_tensors: BlockSparseTensors,
Expand All @@ -130,6 +141,7 @@ def produce_block_sparse_loads(
use_tma_q: cutlass.Constexpr,
tma_q_bytes: cutlass.Constexpr,
intra_wg_overlap: cutlass.Constexpr,
qhead_per_kvhead: cutlass.Constexpr[int] = 1,
):
"""Iterate over the mask and full block lists for a single tile.

Expand All @@ -141,16 +153,21 @@ def produce_block_sparse_loads(
while we advance the producer state to start the next full K. Either the full list
overlaps that pending V load, or, if no full blocks exist, we explicitly drain it.

Args:
qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
must be converted to unpacked for sparse tensor indexing.
"""

mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block]
curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None]
m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead)

curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

if const_expr(full_block_cnt is not None):
curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block]
curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None]
curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
else:
curr_full_block_cnt = Int32(0)
curr_full_block_idx = None
Expand Down Expand Up @@ -290,18 +307,26 @@ def consume_block_sparse_loads(
intra_wg_overlap: cutlass.Constexpr,
warp_scheduler_barrier_sync: Callable,
warp_scheduler_barrier_arrive: Callable,
qhead_per_kvhead: cutlass.Constexpr[int] = 1,
):
"""Consume the mask and full block lists for a single tile on the consumer side.

Mirrors `produce_block_sparse_loads` so that the consumer pipeline
Mirrors `produce_block_sparse_loads` so that the consumer pipeline uses
the same sparse tensor indexing.

Args:
qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
must be converted to unpacked for sparse tensor indexing.
"""

mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block]
curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None]
curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block]
curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None]
m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead)

curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]

processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0

Expand Down
2 changes: 2 additions & 0 deletions flash_attn/cute/flash_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1857,6 +1857,7 @@ def load(
self.use_tma_Q,
self.tma_copy_bytes["Q"],
self.intra_wg_overlap,
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
)

tile_scheduler.prefetch_next_work()
Expand Down Expand Up @@ -2167,6 +2168,7 @@ def mma(
self.intra_wg_overlap,
self.warp_scheduler_barrier_sync,
self.warp_scheduler_barrier_arrive,
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
)

# Handle empty case (when no blocks to process)
Expand Down
3 changes: 0 additions & 3 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,6 @@ def _flash_attn_fwd(
# NB: pack_gqa requires block sparse head dim == 1 (broadcasted)
if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1:
pack_gqa = False
# SM90 doesn't support pack_gqa + block_sparsity yet
if pack_gqa and compute_capability == 9:
pack_gqa = False
if is_split_kv:
raise NotImplementedError(
"Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split."
Expand Down
8 changes: 4 additions & 4 deletions tests/cute/test_mask_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ def _run_mask_test(
nheads_kv = nheads
pack_gqa = False
elif kv_mode == "gqa":
if COMPUTE_CAPABILITY != 10:
pytest.xfail("pack_gqa requires SM100")
if COMPUTE_CAPABILITY < 9:
pytest.xfail("pack_gqa requires SM90+")
nheads_kv = nheads // 4
pack_gqa = True
elif kv_mode == "mqa":
Expand Down Expand Up @@ -240,7 +240,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias):
) = bm.as_tuple()

# SM90 block-sparse backward expects BlockMask granularity (128, 128) regardless of fwd tiling.
if COMPUTE_CAPABILITY == 9 and use_block_sparsity:
if COMPUTE_CAPABILITY == 9 and use_block_sparsity and (sparse_tile_m, tile_n) != (128, 128):
bm_bwd = create_block_mask(
mask_mod_flex,
batch_size,
Expand Down Expand Up @@ -367,7 +367,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias):
f"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}"
)

if needs_backward and kv_mode == "mha":
if needs_backward:
q = tensors["q"]
k = tensors["k"]
v = tensors["v"]
Expand Down