diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index 706e3d6ad2..bcc957bffb 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -493,20 +493,31 @@ def produce_block_sparse_loads_sm100( pipeline_kv, q_stage: cutlass.Constexpr, q_producer_phase: Int32, + qhead_per_kvhead: cutlass.Constexpr, ): """SM100 entry point for sparse block iteration. SM100 uses PipelineTmaUmma which doesn't support extra_tx_count, so we use simplified block processing that just calls producer_acquire without extras. + + Args: + m_block: which tile of m we are processing + qhead_per_kvhead: Constexpr pack factor """ + # NB: Compute unpacked index for sparse tensor access + if const_expr(qhead_per_kvhead != 1): + m_block_sparse = m_block // qhead_per_kvhead + else: + m_block_sparse = m_block + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors - curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] - curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] if const_expr(full_block_cnt is not None): - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] else: curr_full_block_cnt = Int32(0) curr_full_block_idx = None @@ -574,15 +585,22 @@ def get_total_block_count( batch_idx, head_idx, m_block, + qhead_per_kvhead: cutlass.Constexpr, ): + # NB: Convert packed m_block to unpacked for sparse tensor indexing + if const_expr(qhead_per_kvhead != 1): + m_block_sparse = m_block // qhead_per_kvhead + else: + m_block_sparse = m_block + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors if const_expr(full_block_cnt is not None): return ( - mask_block_cnt[batch_idx, head_idx, m_block] - + full_block_cnt[batch_idx, head_idx, m_block] + mask_block_cnt[batch_idx, head_idx, m_block_sparse] + + full_block_cnt[batch_idx, head_idx, m_block_sparse] ) else: - return mask_block_cnt[batch_idx, head_idx, m_block] + return mask_block_cnt[batch_idx, head_idx, m_block_sparse] @cute.jit @@ -717,16 +735,23 @@ def softmax_block_sparse_sm100( mbar_P_full_2_offset: Int32, q_stage: cutlass.Constexpr, stage_idx: Int32, - check_m_boundary: bool = False, + check_m_boundary: bool, + qhead_per_kvhead: cutlass.Constexpr, ): + # Convert packed m_block to unpacked for sparse tensor indexing + if const_expr(qhead_per_kvhead != 1): + m_block_sparse = m_block // qhead_per_kvhead + else: + m_block_sparse = m_block + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors - curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] - curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] if const_expr(full_block_cnt is not None): - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] else: curr_full_block_cnt = Int32(0) curr_full_block_idx = None diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 701dda997d..ac2bda9103 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1291,6 +1291,7 @@ def load( pipeline_kv, self.q_stage, q_producer_phase, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) @@ -1366,7 +1367,7 @@ def mma( process_tile = False if const_expr(self.use_block_sparsity): - block_iter_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block) + block_iter_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) process_tile = block_iter_count > Int32(0) else: n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) @@ -1674,7 +1675,7 @@ def softmax_loop( softmax.reset() if const_expr(self.use_block_sparsity): - tile_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block) + tile_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) has_work = tile_block_count > Int32(0) else: tile_block_count = n_block_max - n_block_min @@ -1742,6 +1743,7 @@ def softmax_loop( self.q_stage, Int32(stage), check_m_boundary, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) if not empty_tile: sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] @@ -2034,7 +2036,7 @@ def correction_loop( stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage if const_expr(self.use_block_sparsity): - total_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block) + total_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) has_work = total_block_count > Int32(0) else: total_block_count = n_block_max - n_block_min diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 103eb55f5a..3f7ad37b68 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -26,11 +26,6 @@ import torch -@lru_cache(maxsize=None) -def _get_device_capability(): - """Cached device capability check.""" - return torch.cuda.get_device_capability()[0] - import cuda.bindings.driver as cuda import cutlass @@ -55,6 +50,11 @@ def _get_device_capability(): get_block_sparse_expected_shapes_bwd, ) +@lru_cache(maxsize=None) +def _get_device_capability(): + """Cached device capability check.""" + return torch.cuda.get_device_capability()[0] + def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x @@ -327,20 +327,18 @@ def _flash_attn_fwd( raise NotImplementedError( "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR." ) - if pack_gqa: - raise NotImplementedError( - "mask_mod with aux_tensors is not yet supported with pack_gqa=True. This will be fixed in a future PR." - ) if use_block_sparsity: if is_varlen: raise NotImplementedError( "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR." ) - if pack_gqa: - raise NotImplementedError( - "Block sparsity is not yet supported with pack_gqa=True. This will be fixed in a future PR." - ) + # NB: pack_gqa requires block sparse head dim == 1 (broadcasted) + if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1: + pack_gqa = False + # SM90 doesn't support pack_gqa + block_sparsity yet + if pack_gqa and compute_capability == 9: + pack_gqa = False if is_split_kv: raise NotImplementedError( "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split." @@ -506,7 +504,6 @@ def _flash_attn_fwd( expected_count_shape=expected_count_shape, expected_index_shape=expected_index_shape, ) - _flash_attn_fwd.compile_cache[compile_key]( q, k, diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 0a772fa425..1d92228e97 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -144,8 +144,14 @@ def apply_mask( for r in cutlass.range_constexpr(nrow): global_row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m row_for_mod = global_row_idx + head_idx_for_mod = head_idx + if const_expr(self.qhead_per_kvhead_packgqa != 1): + head_offset = global_row_idx % self.qhead_per_kvhead_packgqa + head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset + row_for_mod = global_row_idx // self.qhead_per_kvhead_packgqa + row_for_seqlen = row_for_mod if const_expr(wrap_aux_indices): - _, row_for_mod = divmod(global_row_idx, fastdiv_mods[0]) + _, row_for_mod = divmod(row_for_mod, fastdiv_mods[0]) for col in cutlass.range_constexpr(ncol): col_idx_local = t0ScS_mn[0, col][1] @@ -156,7 +162,7 @@ def apply_mask( _, col_for_mod = divmod(global_col_idx, fastdiv_mods[1]) batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) - head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32) + head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32) q_idx_ssa = utils.scalar_to_ssa(row_for_mod, cutlass.Int32) kv_idx_ssa = utils.scalar_to_ssa(col_for_mod, cutlass.Int32) mask_value = mask_mod( @@ -168,7 +174,7 @@ def apply_mask( ) cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) if const_expr(mask_seqlen): - out_of_bounds = (global_row_idx >= self.seqlen_q) or ( + out_of_bounds = (row_for_seqlen >= self.seqlen_q) or ( global_col_idx >= self.seqlen_k ) if out_of_bounds: @@ -346,26 +352,32 @@ def apply_mask_sm100( and fastdiv_mods[1] is not None ) batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) - head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32) - row_coord_first = tScS_t2r[0][0] - global_row = row_coord_first + m_block * self.tile_m - if const_expr(self.qhead_per_kvhead_packgqa != 1): - mask_row = global_row // self.qhead_per_kvhead_packgqa - else: - mask_row = global_row - mask_row_for_mod = mask_row - if const_expr(has_fastdiv and aux_tensors is not None): - if check_q_boundary: - _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0]) - mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32) ncol = const_expr(cute.size(tScS_t2r.shape)) for i in cutlass.range_constexpr(ncol): + row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1] col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0] + global_row = row_coord + m_block * self.tile_m global_col = col_coord + n_block * self.tile_n + + if const_expr(self.qhead_per_kvhead_packgqa != 1): + head_offset = global_row % self.qhead_per_kvhead_packgqa + head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset + mask_row = global_row // self.qhead_per_kvhead_packgqa + else: + head_idx_for_mod = head_idx + mask_row = global_row + + mask_row_for_mod = mask_row + if const_expr(has_fastdiv and aux_tensors is not None): + if check_q_boundary: + _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0]) global_col_for_mod = global_col if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None): _, global_col_for_mod = divmod(global_col, fastdiv_mods[1]) + + head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32) + mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32) kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32) mask_value = mask_mod( batch_idx_ssa, @@ -379,7 +391,7 @@ def apply_mask_sm100( if const_expr(mask_seqlen): acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i] if check_q_boundary: - acc_S[i] = -Float32.inf if global_row >= self.seqlen_q else acc_S[i] + acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i] else: # Causal or local causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q diff --git a/flash_attn/cute/mask_definitions.py b/flash_attn/cute/mask_definitions.py index 546adf17f3..8f2e4b33cc 100644 --- a/flash_attn/cute/mask_definitions.py +++ b/flash_attn/cute/mask_definitions.py @@ -219,21 +219,22 @@ def cute_ima_mask( def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): + """Generate synthetic document ids shared across heads.""" doc_ids_tensor = torch.zeros(batch, nheads, seqlen_q, dtype=torch.int32, device=device) for b in range(batch): + N = seqlen_q + max_segments = max(1, math.ceil(math.sqrt(max(N // 4, 1)))) + n = random.randint(1, max_segments) + n = min(n, N) + cuts = sorted(random.sample(range(1, N), n - 1)) + lengths = [b - a for a, b in zip((0, *cuts), (*cuts, N))] + base_doc_ids = torch.repeat_interleave( + torch.arange(len(lengths), device=device, dtype=torch.int32), + torch.tensor(lengths, device=device, dtype=torch.int32), + ) + for h in range(nheads): - N = seqlen_q - max_segments = max(1, math.ceil(math.sqrt(max(N // 4, 1)))) - n = random.randint(1, max_segments) - n = min(n, N) - cuts = sorted(random.sample(range(1, N), n - 1)) - lengths = [b - a for a, b in zip((0, *cuts), (*cuts, N))] - - doc_ids = [] - for i, length in enumerate(lengths): - doc_ids += [i for _ in range(length)] - - doc_ids_tensor[b, h, :] = torch.tensor(doc_ids, dtype=torch.int32, device=device) + doc_ids_tensor[b, h, :] = base_doc_ids return doc_ids_tensor diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index f40304e6c5..f39975be59 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -162,10 +162,15 @@ def _run_mask_test( # Determine nheads_kv based on mode if kv_mode == "mha": nheads_kv = nheads + pack_gqa = False elif kv_mode == "gqa": - nheads_kv = nheads // 2 + if COMPUTE_CAPABILITY != 10: + pytest.skip("pack_gqa requires SM100") + nheads_kv = nheads // 4 + pack_gqa = True elif kv_mode == "mqa": nheads_kv = 1 + pack_gqa = False else: raise ValueError(f"Unknown kv_mode: {kv_mode}") @@ -211,10 +216,11 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): else: sparse_tile_m = tile_m + block_mask_nheads = 1 if pack_gqa else nheads bm = create_block_mask( mask_mod_flex, batch_size, - nheads, + block_mask_nheads, seqlen_q, seqlen_k, device="cuda", @@ -270,8 +276,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): learnable_sink=None, m_block_size=tile_m, n_block_size=tile_n, - num_threads=384, - pack_gqa=False, + pack_gqa=pack_gqa, _compute_capability=None, score_mod=None, mask_mod=mask_mod_cute, @@ -626,7 +631,7 @@ def test_static_masks( @pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS_SMOKE) @pytest.mark.parametrize("nheads", [16]) -@pytest.mark.parametrize("kv_mode", ["mha"]) +@pytest.mark.parametrize("kv_mode", ["mha", "gqa"]) @pytest.mark.parametrize("headdim", [128]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("use_block_sparsity", [True, False])