diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index c4aad2cd58a..fe1c4cea812 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -115,6 +115,17 @@ def finish_overlap_v_load( return kv_producer_state +@cute.jit +def sparse_tensor_m_block( + m_block, + qhead_per_kvhead: cutlass.Constexpr[int], +): + """Map packed m_block indices to block-sparse tensor indices.""" + if const_expr(qhead_per_kvhead != 1): + return m_block // qhead_per_kvhead + return m_block + + @cute.jit def produce_block_sparse_loads( blocksparse_tensors: BlockSparseTensors, @@ -130,6 +141,7 @@ def produce_block_sparse_loads( use_tma_q: cutlass.Constexpr, tma_q_bytes: cutlass.Constexpr, intra_wg_overlap: cutlass.Constexpr, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, ): """Iterate over the mask and full block lists for a single tile. @@ -141,16 +153,21 @@ def produce_block_sparse_loads( while we advance the producer state to start the next full K. Either the full list overlaps that pending V load, or, if no full blocks exist, we explicitly drain it. + Args: + qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and + must be converted to unpacked for sparse tensor indexing. """ mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors - curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] - curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead) + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] if const_expr(full_block_cnt is not None): - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] else: curr_full_block_cnt = Int32(0) curr_full_block_idx = None @@ -290,18 +307,26 @@ def consume_block_sparse_loads( intra_wg_overlap: cutlass.Constexpr, warp_scheduler_barrier_sync: Callable, warp_scheduler_barrier_arrive: Callable, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, ): """Consume the mask and full block lists for a single tile on the consumer side. - Mirrors `produce_block_sparse_loads` so that the consumer pipeline + Mirrors `produce_block_sparse_loads` so that the consumer pipeline uses + the same sparse tensor indexing. + + Args: + qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and + must be converted to unpacked for sparse tensor indexing. """ mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors - curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] - curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead) + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0 diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index fe72582ebc9..c341d26fbbf 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -1857,6 +1857,7 @@ def load( self.use_tma_Q, self.tma_copy_bytes["Q"], self.intra_wg_overlap, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) tile_scheduler.prefetch_next_work() @@ -2167,6 +2168,7 @@ def mma( self.intra_wg_overlap, self.warp_scheduler_barrier_sync, self.warp_scheduler_barrier_arrive, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) # Handle empty case (when no blocks to process) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 925adf9a194..c902a17bb6e 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -335,9 +335,6 @@ def _flash_attn_fwd( # NB: pack_gqa requires block sparse head dim == 1 (broadcasted) if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1: pack_gqa = False - # SM90 doesn't support pack_gqa + block_sparsity yet - if pack_gqa and compute_capability == 9: - pack_gqa = False if is_split_kv: raise NotImplementedError( "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split." diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index bad320fe5ce..847cfe8588a 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -163,8 +163,8 @@ def _run_mask_test( nheads_kv = nheads pack_gqa = False elif kv_mode == "gqa": - if COMPUTE_CAPABILITY != 10: - pytest.xfail("pack_gqa requires SM100") + if COMPUTE_CAPABILITY < 9: + pytest.xfail("pack_gqa requires SM90+") nheads_kv = nheads // 4 pack_gqa = True elif kv_mode == "mqa": @@ -240,7 +240,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): ) = bm.as_tuple() # SM90 block-sparse backward expects BlockMask granularity (128, 128) regardless of fwd tiling. - if COMPUTE_CAPABILITY == 9 and use_block_sparsity: + if COMPUTE_CAPABILITY == 9 and use_block_sparsity and (sparse_tile_m, tile_n) != (128, 128): bm_bwd = create_block_mask( mask_mod_flex, batch_size, @@ -367,7 +367,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): f"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" ) - if needs_backward and kv_mode == "mha": + if needs_backward: q = tensors["q"] k = tensors["k"] v = tensors["v"]