diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 2811e7c4551..11f6249ef20 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -708,6 +708,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s set_params_splitkv(params, batch_size, num_heads, head_size, max_seqlen_k, max_seqlen_q, head_size_rounded, p_dropout, num_splits, get_num_sm(get_current_device()), opts); + } else if (paged_KV) { + TORCH_CHECK(num_splits <= 1, "num_splits > 1 is not supported for varlen paged KV"); + params.num_splits = num_splits; } if (leftpad_k_.has_value()) { diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index 422da2b66a0..35bb4365ff6 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -143,8 +143,8 @@ def get_n_block_max_for_m_block( self, seqlen_info: SeqlenInfoQK, m_block: Int32, - n_block_global_max: Int32, ) -> Int32: + n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) if const_expr(self.is_causal or self.window_size_right is not None): m_idx_max = (m_block + 1) * self.tile_m if const_expr(self.qhead_per_kvhead_packgqa > 1): @@ -152,5 +152,5 @@ def get_n_block_max_for_m_block( n_idx_right = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q if const_expr(self.window_size_right is not None): n_idx_right += self.window_size_right - return min(n_block_global_max, cute.ceil_div(n_idx_right, self.tile_n)) - return n_block_global_max + n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.tile_n)) + return n_block_max diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index d664b16dc64..fb131745b3b 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -5,7 +5,7 @@ These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads. """ -from typing import Callable, Optional +from typing import Callable, Optional, Tuple from functools import partial import math import cutlass @@ -17,6 +17,67 @@ # Import data structures from block_sparsity from flash_attn.cute.block_sparsity import BlockSparseTensors from flash_attn.cute.named_barrier import NamedBarrierBwd +from flash_attn.cute.seqlen_info import SeqlenInfoQK + + +@cute.jit +def _get_curr_blocksparse_tensors_varlen( + head_idx: cutlass.Int32, + m_block: cutlass.Int32, + blocksparse_tensors: BlockSparseTensors, + seqlen_info: SeqlenInfoQK, +) -> Tuple[cutlass.Int32, cute.Tensor, cutlass.Int32, Optional[cute.Tensor]]: + """Varlen path: tensors are 2D [nheads, total_m_blocks] / [nheads, total_n_blocks].""" + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors + curr_m_block = seqlen_info.m_block_offset + m_block + curr_block_idx_offset = seqlen_info.block_idx_offset + m_block * seqlen_info.num_n_blocks + curr_mask_block_cnt = mask_block_cnt[head_idx, curr_m_block] + curr_mask_block_idx = cute.domain_offset(curr_block_idx_offset, mask_block_idx[head_idx, None]) + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = full_block_cnt[head_idx, curr_m_block] + curr_full_block_idx = cute.domain_offset( + curr_block_idx_offset, full_block_idx[head_idx, None] + ) + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None + return (curr_mask_block_cnt, curr_mask_block_idx, curr_full_block_cnt, curr_full_block_idx) + + +@cute.jit +def _get_curr_blocksparse_tensors( + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, + m_block: cutlass.Int32, + blocksparse_tensors: BlockSparseTensors, +) -> Tuple[cutlass.Int32, cute.Tensor, cutlass.Int32, Optional[cute.Tensor]]: + """Fixed-length path: tensors are 4D [batch, nheads, m_block, n_block].""" + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None + return (curr_mask_block_cnt, curr_mask_block_idx, curr_full_block_cnt, curr_full_block_idx) + + +@cute.jit +def get_curr_blocksparse_tensors( + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, + m_block: cutlass.Int32, + blocksparse_tensors: BlockSparseTensors, + seqlen_info: SeqlenInfoQK, +) -> Tuple[cutlass.Int32, cute.Tensor, cutlass.Int32, Optional[cute.Tensor]]: + """Extract head, m_block, and batch-local blocksparsity data from blocksparse_tensors""" + if const_expr(len(blocksparse_tensors.mask_block_cnt.shape) == 2): + return _get_curr_blocksparse_tensors_varlen( + head_idx, m_block, blocksparse_tensors, seqlen_info + ) + return _get_curr_blocksparse_tensors(batch_idx, head_idx, m_block, blocksparse_tensors) # NOTE [SM100 block-sparse empty tiles: mbarrier contract] @@ -162,6 +223,7 @@ def produce_block_sparse_loads( batch_idx, head_idx, m_block, + seqlen_info: SeqlenInfoQK, kv_producer_state, load_K, load_V, @@ -187,20 +249,20 @@ def produce_block_sparse_loads( qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and must be converted to unpacked for sparse tensor indexing. """ - - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors - m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] - - if const_expr(full_block_cnt is not None): - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] - else: - curr_full_block_cnt = Int32(0) - curr_full_block_idx = None + ( + curr_mask_block_cnt, + curr_mask_block_idx, + curr_full_block_cnt, + curr_full_block_idx, + ) = get_curr_blocksparse_tensors( + batch_idx, + head_idx, + m_block_sparse, + blocksparse_tensors, + seqlen_info, + ) mask_empty = curr_mask_block_cnt == 0 full_empty = curr_full_block_cnt == 0 @@ -305,7 +367,7 @@ def consume_block_sparse_loads( batch_idx, head_idx, m_block, - seqlen, + seqlen_info, kv_consumer_state, mma_pv_fn, mma_one_n_block, @@ -331,15 +393,20 @@ def consume_block_sparse_loads( qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and must be converted to unpacked for sparse tensor indexing. """ - - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors - m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] + ( + curr_mask_block_cnt, + curr_mask_block_idx, + curr_full_block_cnt, + curr_full_block_idx, + ) = get_curr_blocksparse_tensors( + batch_idx, + head_idx, + m_block_sparse, + blocksparse_tensors, + seqlen_info, + ) processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0 @@ -420,7 +487,7 @@ def consume_block_sparse_loads( mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] kv_consumer_state = process_first_half_block( n_block=mask_n_block, - seqlen=seqlen, + seqlen=seqlen_info, kv_consumer_state=kv_consumer_state, mask_fn=partial( mask_fn, @@ -436,7 +503,7 @@ def consume_block_sparse_loads( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=mask_n_block, - seqlen=seqlen, + seqlen=seqlen_info, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False), ) @@ -447,7 +514,7 @@ def consume_block_sparse_loads( if curr_mask_block_cnt == 0: kv_consumer_state = process_first_half_block( n_block=full_n_block, - seqlen=seqlen, + seqlen=seqlen_info, kv_consumer_state=kv_consumer_state, mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), score_mod_fn=score_mod_fn, @@ -457,7 +524,7 @@ def consume_block_sparse_loads( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=full_n_block, - seqlen=seqlen, + seqlen=seqlen_info, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), ) @@ -467,7 +534,7 @@ def consume_block_sparse_loads( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=full_n_block, - seqlen=seqlen, + seqlen=seqlen_info, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False), ) @@ -483,10 +550,20 @@ def consume_block_sparse_loads( return kv_consumer_state, O_should_accumulate, processed_any +@cute.jit +def split_block_range(block_count, split_idx: Int32, num_splits: Int32): + """Return the half-open block-list range assigned to one SplitKV partition.""" + blocks_per_split = cute.ceil_div(block_count, num_splits) + block_begin = cutlass.min(split_idx * blocks_per_split, block_count) + block_end = cutlass.min(block_begin + blocks_per_split, block_count) + return block_begin, block_end + + @cute.jit def load_block_list_sm100( block_indices: cute.Tensor, - block_count, + block_begin, + block_end, load_q_with_first: cutlass.Constexpr, q_stage: cutlass.Constexpr, kv_producer_state, @@ -496,9 +573,10 @@ def load_block_list_sm100( pipeline_kv, ): """SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count).""" + block_count = block_end - block_begin if block_count > 0: # First iteration: load Q alongside K if requested - n_block_first = block_indices[block_count - 1] + n_block_first = block_indices[block_end - 1] if const_expr(load_q_with_first): # SM100 loads Q0 and optionally Q1 @@ -515,7 +593,7 @@ def load_block_list_sm100( # Remaining blocks for offset in cutlass.range(1, block_count): - n_block = block_indices[block_count - 1 - offset] + n_block = block_indices[block_end - 1 - offset] load_K(block=n_block, producer_state=kv_producer_state, page_idx=None) kv_producer_state.advance() load_V(block=n_block, producer_state=kv_producer_state, page_idx=None) @@ -531,6 +609,9 @@ def produce_block_sparse_loads_sm100( batch_idx, head_idx, m_block, + seqlen_info: SeqlenInfoQK, + split_idx: Int32, + num_splits: Int32, kv_producer_state, load_Q, load_K, @@ -552,20 +633,23 @@ def produce_block_sparse_loads_sm100( """ m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors - - curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] - - if const_expr(full_block_cnt is not None): - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] - else: - curr_full_block_cnt = Int32(0) - curr_full_block_idx = None + ( + curr_mask_block_cnt, + curr_mask_block_idx, + curr_full_block_cnt, + curr_full_block_idx, + ) = get_curr_blocksparse_tensors( + batch_idx, + head_idx, + m_block_sparse, + blocksparse_tensors, + seqlen_info, + ) - mask_empty = curr_mask_block_cnt == 0 - full_empty = curr_full_block_cnt == 0 + mask_begin, mask_end = split_block_range(curr_mask_block_cnt, split_idx, num_splits) + full_begin, full_end = split_block_range(curr_full_block_cnt, split_idx, num_splits) + mask_empty = mask_begin == mask_end + full_empty = full_begin == full_end q_phase_flipped = False @@ -573,7 +657,8 @@ def produce_block_sparse_loads_sm100( # No masked blocks: process full list with Q loading kv_producer_state = load_block_list_sm100( curr_full_block_idx, - curr_full_block_cnt, + full_begin, + full_end, load_q_with_first=True, q_stage=q_stage, kv_producer_state=kv_producer_state, @@ -587,7 +672,8 @@ def produce_block_sparse_loads_sm100( # Process masked blocks with Q loading kv_producer_state = load_block_list_sm100( curr_mask_block_idx, - curr_mask_block_cnt, + mask_begin, + mask_end, load_q_with_first=True, q_stage=q_stage, kv_producer_state=kv_producer_state, @@ -602,7 +688,8 @@ def produce_block_sparse_loads_sm100( # Process full blocks without Q loading kv_producer_state = load_block_list_sm100( curr_full_block_idx, - curr_full_block_cnt, + full_begin, + full_end, load_q_with_first=False, q_stage=q_stage, kv_producer_state=kv_producer_state, @@ -624,19 +711,29 @@ def get_total_block_count( batch_idx, head_idx, m_block, + split_idx: Int32, + num_splits: Int32, qhead_per_kvhead: cutlass.Constexpr, q_subtile_factor: cutlass.Constexpr, + seqlen_info: SeqlenInfoQK, ): m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) + ( + curr_mask_block_cnt, + _, + curr_full_block_cnt, + _, + ) = get_curr_blocksparse_tensors( + batch_idx, + head_idx, + m_block_sparse, + blocksparse_tensors, + seqlen_info, + ) - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors - if const_expr(full_block_cnt is not None): - return ( - mask_block_cnt[batch_idx, head_idx, m_block_sparse] - + full_block_cnt[batch_idx, head_idx, m_block_sparse] - ) - else: - return mask_block_cnt[batch_idx, head_idx, m_block_sparse] + mask_begin, mask_end = split_block_range(curr_mask_block_cnt, split_idx, num_splits) + full_begin, full_end = split_block_range(curr_full_block_cnt, split_idx, num_splits) + return mask_end - mask_begin + full_end - full_begin @cute.jit @@ -649,7 +746,7 @@ def handle_block_sparse_empty_tile_correction_sm100( is_split_kv: cutlass.Constexpr, learnable_sink, mLSE, - seqlen, + seqlen_info, m_block: Int32, head_idx: Int32, batch_idx: Int32, @@ -737,7 +834,7 @@ def handle_block_sparse_empty_tile_correction_sm100( tidx, stage, m_block, - seqlen.seqlen_q, + seqlen_info.seqlen_q, Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs sO[None, None, stage], mO_cur, @@ -763,6 +860,9 @@ def softmax_block_sparse_sm100( batch_idx, head_idx, m_block, + seqlen_info: SeqlenInfoQK, + split_idx: Int32, + num_splits: Int32, softmax_step: Callable, mask_fn: Callable, mask_fn_none: Callable, @@ -780,25 +880,30 @@ def softmax_block_sparse_sm100( warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors - - curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] - - if const_expr(full_block_cnt is not None): - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] - else: - curr_full_block_cnt = Int32(0) - curr_full_block_idx = None + ( + curr_mask_block_cnt, + curr_mask_block_idx, + curr_full_block_cnt, + curr_full_block_idx, + ) = get_curr_blocksparse_tensors( + batch_idx, + head_idx, + m_block_sparse, + blocksparse_tensors, + seqlen_info, + ) - total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt + mask_begin, mask_end = split_block_range(curr_mask_block_cnt, split_idx, num_splits) + full_begin, full_end = split_block_range(curr_full_block_cnt, split_idx, num_splits) + split_mask_block_cnt = mask_end - mask_begin + split_full_block_cnt = full_end - full_begin + total_block_cnt = split_mask_block_cnt + split_full_block_cnt if total_block_cnt == 0: sm_stats_barrier.arrive_w_index(index=stage_idx * 4 + warp_idx) else: - if curr_mask_block_cnt > 0: - mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] + if split_mask_block_cnt > 0: + mask_n_block = curr_mask_block_idx[mask_end - 1] ( mma_si_consumer_phase, si_corr_producer_phase, @@ -811,8 +916,8 @@ def softmax_block_sparse_sm100( is_first=True, mask_fn=partial(mask_fn, mask_seqlen=True, check_q_boundary=check_m_boundary), ) - for i in cutlass.range(1, curr_mask_block_cnt): - mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + for i in cutlass.range(1, split_mask_block_cnt): + mask_n_block = curr_mask_block_idx[mask_end - 1 - i] ( mma_si_consumer_phase, si_corr_producer_phase, @@ -825,9 +930,9 @@ def softmax_block_sparse_sm100( mask_fn=partial(mask_fn, mask_seqlen=False, check_q_boundary=check_m_boundary), ) - if curr_full_block_cnt > 0: - full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] - if curr_mask_block_cnt == 0: + if split_full_block_cnt > 0: + full_n_block = curr_full_block_idx[full_end - 1] + if split_mask_block_cnt == 0: ( mma_si_consumer_phase, si_corr_producer_phase, @@ -854,11 +959,11 @@ def softmax_block_sparse_sm100( full_n_block, is_first=False, mask_fn=partial( - mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary + mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary ), ) - for i in cutlass.range(1, curr_full_block_cnt): - full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + for i in cutlass.range(1, split_full_block_cnt): + full_n_block = curr_full_block_idx[full_end - 1 - i] ( mma_si_consumer_phase, si_corr_producer_phase, diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index 4a5726b7493..8d28edae1dc 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -17,17 +17,23 @@ def ceildiv(a: int, b: int) -> int: class BlockSparseTensors(NamedTuple): mask_block_cnt: cute.Tensor mask_block_idx: cute.Tensor - full_block_cnt: cute.Tensor | None - full_block_idx: cute.Tensor | None + full_block_cnt: cute.Tensor | None = None + full_block_idx: cute.Tensor | None = None + cu_total_m_blocks: cute.Tensor | None = None + cu_block_idx_offsets: cute.Tensor | None = None dq_write_order: cute.Tensor | None = None dq_write_order_full: cute.Tensor | None = None def __new_from_mlir_values__(self, values): - if len(values) == 2: - values = (*values, None, None, None, None) - elif len(values) == 4: - values = (*values, None, None) - return BlockSparseTensors(*values) + new_fields = [] + idx = 0 + for original in self: + if original is None: + new_fields.append(None) + else: + new_fields.append(values[idx]) + idx += 1 + return BlockSparseTensors(*new_fields) class BlockSparseTensorsTorch(NamedTuple): @@ -35,6 +41,8 @@ class BlockSparseTensorsTorch(NamedTuple): mask_block_idx: torch.Tensor full_block_cnt: torch.Tensor | None = None full_block_idx: torch.Tensor | None = None + cu_total_m_blocks: torch.Tensor | None = None + cu_block_idx_offsets: torch.Tensor | None = None block_size: tuple[int, int] | None = None dq_write_order: torch.Tensor | None = None dq_write_order_full: torch.Tensor | None = None @@ -214,8 +222,8 @@ def _check_and_expand_block( name: str, cnt: torch.Tensor | None, idx: torch.Tensor | None, - expected_count_shape: Tuple[int, int, int], - expected_index_shape: Tuple[int, int, int, int], + expected_count_shape: Tuple[int, ...], + expected_index_shape: Tuple[int, ...], context: str | None, hint: str | Callable[[], str] | None, ) -> Tuple[torch.Tensor | None, torch.Tensor | None]: @@ -402,8 +410,8 @@ def get_block_sparse_expected_shapes_bwd( def normalize_block_sparse_tensors( tensors: BlockSparseTensorsTorch, *, - expected_count_shape: Tuple[int, int, int], - expected_index_shape: Tuple[int, int, int, int], + expected_count_shape: Tuple[int, ...], + expected_index_shape: Tuple[int, ...], context: str | None = None, hint: str | Callable[[], str] | None = None, ) -> BlockSparseTensorsTorch: @@ -461,6 +469,8 @@ def normalize_block_sparse_tensors( mask_block_idx=mask_idx, full_block_cnt=full_cnt, full_block_idx=full_idx, + cu_total_m_blocks=tensors.cu_total_m_blocks, + cu_block_idx_offsets=tensors.cu_block_idx_offsets, block_size=tensors.block_size, dq_write_order=dq_write_order, dq_write_order_full=dq_write_order_full, @@ -516,6 +526,13 @@ def normalize_block_sparse_config( block_size: tuple[int, int], q_stage: int, ) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None, int]: + """Validate the block-sparse config, infer expected shapes, and normalize. + + Handles both fixed-length (3D `[B, H, M]` / 4D `[B, H, M, N]`) and varlen + (2D `[H, total_m_blocks]` / `[H, total_n_blocks]`) layouts. Varlen is + detected by `tensors.cu_total_m_blocks is not None` and forces + `q_subtile_factor == 1` (TODO: potentially remove this restriction). + """ m_block_size, n_block_size = block_size if tensors.block_size is None: sparse_block_size_q, sparse_block_size_kv = None, n_block_size @@ -525,21 +542,34 @@ def normalize_block_sparse_config( raise ValueError( f"Block sparsity requires sparse_block_size[1]={n_block_size} to match tile_n." ) - expected_count_shape, expected_index_shape, q_subtile_factor = ( - infer_block_sparse_expected_shapes( - tensors, - batch_size=batch_size, - num_head=num_head, - seqlen_q=seqlen_q, - seqlen_k=seqlen_k, - m_block_size=m_block_size, - n_block_size=n_block_size, - q_stage=q_stage, - context="forward", - sparse_block_size_q=sparse_block_size_q, - sparse_block_size_kv=sparse_block_size_kv, + if tensors.cu_total_m_blocks is not None: + base_m_block = q_stage * m_block_size + if sparse_block_size_q is not None and sparse_block_size_q != base_m_block: + raise ValueError( + f"Varlen block sparsity requires sparse_block_size[0]={base_m_block} " + f"(= q_stage * tile_m); got {sparse_block_size_q}." + ) + total_m_blocks = tensors.mask_block_cnt.shape[-1] + total_n_blocks = tensors.mask_block_idx.shape[-1] + expected_count_shape = (num_head, total_m_blocks) + expected_index_shape = (num_head, total_n_blocks) + q_subtile_factor = 1 + else: + expected_count_shape, expected_index_shape, q_subtile_factor = ( + infer_block_sparse_expected_shapes( + tensors, + batch_size=batch_size, + num_head=num_head, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + m_block_size=m_block_size, + n_block_size=n_block_size, + q_stage=q_stage, + context="forward", + sparse_block_size_q=sparse_block_size_q, + sparse_block_size_kv=sparse_block_size_kv, + ) ) - ) normalized_tensors = normalize_block_sparse_tensors( tensors, expected_count_shape=expected_count_shape, @@ -615,6 +645,12 @@ def to_cute_block_sparse_tensors( else None for t in (tensors.full_block_cnt, tensors.full_block_idx) ] + cu_total_m_blocks_tensor, cu_block_idx_offsets_tensor = [ + to_cute_tensor(t, assumed_align=4, leading_dim=0, enable_tvm_ffi=enable_tvm_ffi) + if t is not None + else None + for t in (tensors.cu_total_m_blocks, tensors.cu_block_idx_offsets) + ] dq_write_order_tensor, dq_write_order_full_tensor = [ to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi) if t is not None @@ -627,6 +663,8 @@ def to_cute_block_sparse_tensors( mask_block_idx_tensor, full_block_cnt_tensor, full_block_idx_tensor, + cu_total_m_blocks_tensor, + cu_block_idx_offsets_tensor, dq_write_order_tensor, dq_write_order_full_tensor, ) diff --git a/flash_attn/cute/compute_block_sparsity.py b/flash_attn/cute/compute_block_sparsity.py index 69e8309a028..777d3613eb1 100644 --- a/flash_attn/cute/compute_block_sparsity.py +++ b/flash_attn/cute/compute_block_sparsity.py @@ -11,7 +11,19 @@ BlockSparseTensorsTorch, to_cute_block_sparse_tensors, ) -from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar +from flash_attn.cute.block_sparse_utils import get_curr_blocksparse_tensors +from flash_attn.cute.testing import is_fake_mode +from flash_attn.cute.cute_dsl_utils import ( + to_cute_tensor, + get_aux_tensor_metadata, + to_cute_aux_tensor, +) +from flash_attn.cute.utils import ( + hash_callable, + scalar_to_ssa, + ssa_to_scalar, + get_batch_from_cu_tensor, +) from flash_attn.cute.seqlen_info import SeqlenInfoQK @@ -28,7 +40,6 @@ class BlockSparsityKernel: TODO: - optimize mask_mod evaluation - - varlen support - transposed tensors for bwd pass """ @@ -52,18 +63,31 @@ def __call__( blocksparse_tensors: BlockSparseTensors, seqlen_q: Int32, seqlen_k: Int32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, ): - self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx, *_ = blocksparse_tensors + mask_cnt, mask_idx, full_cnt, full_idx, mCuTotalMBlocks, mCuBlockIdxOffsets, *_ = ( + blocksparse_tensors + ) + + self.is_varlen_q = const_expr(mCuSeqlensQ is not None) if const_expr(self.compute_full_blocks): - assert self.full_cnt is not None and self.full_idx is not None, ( + assert full_cnt is not None and full_idx is not None, ( "full block tensors must be provided when computing full blocks" ) - - batch_size, num_heads, num_m_blocks, num_n_blocks = self.mask_idx.shape - # launch 1 CTA per m block - grid = [num_m_blocks, num_heads, batch_size] + if const_expr(not self.is_varlen_q): + batch_size, num_heads, num_m_blocks, _ = mask_idx.shape + total_m_blocks = batch_size * num_m_blocks + else: + assert const_expr(mCuTotalMBlocks is not None), ( + "mCuTotalMBlocks must be provided when varlen q" + ) + num_heads, total_m_blocks = mask_cnt.shape # num_m_blocks is total_m_blocks + batch_size = mCuSeqlensQ.shape[0] - 1 if const_expr(self.use_fast_sampling): num_threads = 5 @@ -72,46 +96,46 @@ def __call__( num_threads = self.tile_mn[0] self.num_warps = (num_threads + 32 - 1) // 32 + if const_expr(not self.is_varlen_q): + grid = [num_m_blocks, num_heads, batch_size] + else: + grid = [total_m_blocks, num_heads, 1] + self.kernel( - self.mask_cnt, - self.mask_idx, - self.full_cnt, - self.full_idx, - num_n_blocks, + blocksparse_tensors, seqlen_q, seqlen_k, + batch_size, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + mCuTotalMBlocks, + mCuBlockIdxOffsets, aux_tensors, ).launch(grid=grid, block=[num_threads, 1, 1]) @cute.kernel def kernel( self, - mask_cnt: cute.Tensor, - mask_idx: cute.Tensor, - full_cnt: cute.Tensor, - full_idx: cute.Tensor, - num_n_blocks: Int32, + blocksparse_tensors: BlockSparseTensors, seqlen_q: Int32, seqlen_k: Int32, + batch_size: Int32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, + mCuBlockIdxOffsets: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, ): tidx, _, _ = cute.arch.thread_idx() warp_idx = cute.arch.warp_idx() lane_id = cute.arch.lane_idx() - m_block, head_idx, batch_idx = cute.arch.block_idx() ssa = partial(scalar_to_ssa, dtype=Int32) - seqlen = SeqlenInfoQK.create( - batch_idx, - seqlen_q, - seqlen_k, - mCuSeqlensQ=None, - mCuSeqlensK=None, - mSeqUsedQ=None, - mSeqUsedK=None, - ) - @cute.struct class SharedStorage: reduction_buffer_smem: cute.struct.Align[ @@ -124,132 +148,154 @@ class SharedStorage: reduction_buffer = storage.reduction_buffer_smem.get_tensor( cute.make_layout((self.num_warps, 2)) ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=seqlen_q, + seqlen_k_static=seqlen_k, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + mCuTotalMBlocks=mCuTotalMBlocks, + mCuBlockIdxOffsets=mCuBlockIdxOffsets, + tile_m=self.tile_mn[0], + tile_n=self.tile_mn[1], + ) + + if const_expr(not self.is_varlen_q): + m_block, head_idx, batch_idx = cute.arch.block_idx() + else: + global_m_block, head_idx, _ = cute.arch.block_idx() + batch_idx = get_batch_from_cu_tensor(global_m_block, mCuTotalMBlocks) + m_block = global_m_block - mCuTotalMBlocks[batch_idx] + + seqlen = SeqlenInfoCls(batch_idx) + seqlen_q = seqlen.seqlen_q + seqlen_k = seqlen.seqlen_k + global_m_block = seqlen.m_block_offset + m_block + + num_n_blocks = (seqlen_k + self.tile_mn[1] - 1) // self.tile_mn[1] + + _, curr_mask_idx, _, curr_full_idx = get_curr_blocksparse_tensors( + batch_idx, head_idx, m_block, blocksparse_tensors, seqlen + ) num_mask_blocks = Int32(0) num_full_blocks = Int32(0) - for n_block in cutlass.range(num_n_blocks, unroll_full=True): - m_base = m_block * self.tile_mn[0] + m_base = m_block * self.tile_mn[0] + if const_expr(self.use_fast_sampling): + # Loop-invariant per-thread q_idx for the 5 sample points + # (tidx 0, 1: top corners; 2, 3: bottom corners; 4: center). + q_idx_sample = m_base + if tidx == 2 or tidx == 3: + q_idx_sample = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1) + elif tidx == 4: + q_idx_sample = m_base + cutlass.min(seqlen_q - m_base, self.tile_mn[0]) // 2 + else: + q_idx_thread = m_base + tidx + thread_in_bounds = Boolean(tidx < self.tile_mn[0] and q_idx_thread < seqlen_q) + + for n_block in cutlass.range(num_n_blocks): n_base = n_block * self.tile_mn[1] if const_expr(self.use_fast_sampling): - # Fast path: 5-point sampling (4 corners + center) - # Clamps OOB indices to nearest in bounds. - thread_result = Boolean(False) - thread_is_valid = Boolean(False) - q_idx = Int32(0) - kv_idx = Int32(0) + # 5-point sampling (4 corners + center). Interior n_blocks + # (n_base + tile_n <= seqlen_k) skip the OOB clamp on the right / + # center samples. + is_interior = (n_base + self.tile_mn[1]) <= seqlen_k + n_right = Int32(0) + n_mid = Int32(0) + if is_interior: + n_right = n_base + self.tile_mn[1] - 1 + n_mid = n_base + self.tile_mn[1] // 2 + else: + n_right = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1) + n_mid = n_base + cutlass.min(seqlen_k - n_base, self.tile_mn[1]) // 2 - if tidx == 0: - # Top-left corner (0, 0); always in bounds - q_idx = m_base - kv_idx = n_base - elif tidx == 1: - # Top-right corner - q_idx = m_base - kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1) - elif tidx == 2: - # Bottom-left corner - q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1) - kv_idx = n_base - elif tidx == 3: - # Bottom-right corner - q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1) - kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1) + kv_idx = n_base + if tidx == 1 or tidx == 3: + kv_idx = n_right elif tidx == 4: - # Center point - q_idx = m_base + (cutlass.min(seqlen_q - m_base, self.tile_mn[0])) // 2 - kv_idx = n_base + (cutlass.min(seqlen_k - n_base, self.tile_mn[1])) // 2 - else: - thread_is_valid = Boolean(False) + kv_idx = n_mid - # Check bounds and determine if this thread has a valid index pair - if tidx < 5 and q_idx < seqlen_q and kv_idx < seqlen_k: + thread_result = Boolean(False) + thread_is_valid = Boolean(False) + if tidx < 5: thread_is_valid = Boolean(True) - q_idx_ssa = ssa(q_idx) - kv_idx_ssa = ssa(kv_idx) thread_result = ssa_to_scalar( self.mask_mod( ssa(batch_idx), ssa(head_idx), - q_idx_ssa, - kv_idx_ssa, + ssa(q_idx_sample), + ssa(kv_idx), seqlen, aux_tensors, ) ) - else: - thread_is_valid = Boolean(False) - # Use vote_any_sync to see if any valid thread found unmasked or masked - # Only count results from threads that checked valid indices has_unmasked = cute.arch.vote_any_sync(thread_result & thread_is_valid) - has_masked = cute.arch.vote_any_sync((Boolean(not thread_result)) & thread_is_valid) + has_masked = cute.arch.vote_any_sync(Boolean(not thread_result) & thread_is_valid) else: - # Full path: check all elements in the block - # Track if this thread's row has any masked or unmasked elements + # Full path. Interior blocks (n_base + tile_n <= seqlen_k) drop the + # per-element bound check; the boundary block (at most one) keeps it. thread_has_unmasked = Boolean(False) thread_has_masked = Boolean(False) - thread_is_valid = Boolean(False) - - # Each thread handles 1 row - q_idx = m_base + tidx kv_idx = Int32(0) - if tidx < self.tile_mn[0] and q_idx < seqlen_q: - thread_is_valid = Boolean(True) - q_idx_ssa = ssa(q_idx) - - # Loop over all columns in this row - for c in cutlass.range(self.tile_mn[1], unroll_full=True): - kv_idx = n_base + c - kv_idx_ssa = ssa(kv_idx) + is_interior = (n_base + self.tile_mn[1]) <= seqlen_k - # Only check elements within valid sequence bounds - if kv_idx < seqlen_k: - # Direct scalar call + if is_interior: + if thread_in_bounds: + for c in cutlass.range(self.tile_mn[1], unroll_full=True): mask_val = ssa_to_scalar( self.mask_mod( ssa(batch_idx), ssa(head_idx), - q_idx_ssa, - kv_idx_ssa, + ssa(q_idx_thread), + ssa(n_base + c), seqlen, aux_tensors, ) ) + thread_has_unmasked |= Boolean(mask_val) + thread_has_masked |= Boolean(not mask_val) + else: + if thread_in_bounds: + for c in cutlass.range(self.tile_mn[1], unroll_full=True): + kv_idx = n_base + c + if kv_idx < seqlen_k: + mask_val = ssa_to_scalar( + self.mask_mod( + ssa(batch_idx), + ssa(head_idx), + ssa(q_idx_thread), + ssa(kv_idx), + seqlen, + aux_tensors, + ) + ) + thread_has_unmasked |= Boolean(mask_val) + thread_has_masked |= Boolean(not mask_val) - # Update tracking flags - if mask_val: - thread_has_unmasked = Boolean(True) - else: - thread_has_masked = Boolean(True) - - # Block-level reduction to combine results across all threads - # Only count votes from threads that checked valid indices - warp_has_unmasked_mask = cute.arch.vote_any_sync( - thread_has_unmasked & thread_is_valid - ) - warp_has_masked_mask = cute.arch.vote_any_sync(thread_has_masked & thread_is_valid) - - # lane 0 writes the ballot mask to shared memory - lane_id = tidx % 32 + warp_unmasked = cute.arch.vote_any_sync(thread_has_unmasked & thread_in_bounds) + warp_masked = cute.arch.vote_any_sync(thread_has_masked & thread_in_bounds) if lane_id == 0: - # Store as Int8 - reduction_buffer[warp_idx, 0] = Int8(1) if warp_has_unmasked_mask else Int8(0) - reduction_buffer[warp_idx, 1] = Int8(1) if warp_has_masked_mask else Int8(0) - + reduction_buffer[warp_idx, 0] = Int8(1) if warp_unmasked else Int8(0) + reduction_buffer[warp_idx, 1] = Int8(1) if warp_masked else Int8(0) cute.arch.sync_threads() - # Thread 0 ORs all warp results together + # Cross-warp OR via warp 0; thread 0 (lane 0 of warp 0) holds the result. has_unmasked = Boolean(False) has_masked = Boolean(False) - if tidx == 0: - for w in cutlass.range(self.num_warps): - if reduction_buffer[w, 0]: - has_unmasked = Boolean(True) - if reduction_buffer[w, 1]: - has_masked = Boolean(True) + if warp_idx == 0: + lane_unmasked = Boolean(False) + lane_masked = Boolean(False) + if lane_id < self.num_warps: + lane_unmasked = reduction_buffer[lane_id, 0] != Int8(0) + lane_masked = reduction_buffer[lane_id, 1] != Int8(0) + has_unmasked = cute.arch.vote_any_sync(lane_unmasked) + has_masked = cute.arch.vote_any_sync(lane_masked) # Only thread 0 updates the output arrays (common to both paths) if tidx == 0: @@ -261,17 +307,23 @@ class SharedStorage: is_full = Boolean(has_unmasked and (not has_masked)) if is_partial: - mask_idx[batch_idx, head_idx, m_block, num_mask_blocks] = n_block + curr_mask_idx[num_mask_blocks] = n_block num_mask_blocks += 1 elif is_full and const_expr(self.compute_full_blocks): - full_idx[batch_idx, head_idx, m_block, num_full_blocks] = n_block + curr_full_idx[num_full_blocks] = n_block num_full_blocks += 1 # Only thread 0 writes back the counts if tidx == 0: - mask_cnt[batch_idx, head_idx, m_block] = num_mask_blocks - if const_expr(self.compute_full_blocks): - full_cnt[batch_idx, head_idx, m_block] = num_full_blocks + mask_cnt, _, full_cnt, *_ = blocksparse_tensors + if const_expr(self.is_varlen_q): + mask_cnt[head_idx, global_m_block] = num_mask_blocks + if const_expr(self.compute_full_blocks): + full_cnt[head_idx, global_m_block] = num_full_blocks + else: + mask_cnt[batch_idx, head_idx, m_block] = num_mask_blocks + if const_expr(self.compute_full_blocks): + full_cnt[batch_idx, head_idx, m_block] = num_full_blocks def compute_block_sparsity( @@ -282,11 +334,17 @@ def compute_block_sparsity( seqlen_q, seqlen_k, mask_mod: Callable, - aux_tensors: Optional[list], # list[cute.Tensor] + aux_tensors: Optional[list], device, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + cu_total_m_blocks: Optional[torch.Tensor] = None, + cu_block_idx_offsets: Optional[torch.Tensor] = None, compute_full_blocks: bool = True, use_fast_sampling: bool = False, -) -> Tuple[BlockSparseTensors, BlockSparseTensorsTorch]: +) -> BlockSparseTensorsTorch: """ Computes block sparsity for a given `mask_mod`. @@ -300,59 +358,141 @@ def compute_block_sparsity( mask_mod: The `mask_mod` callable to use. aux_tensors: A list of auxiliary tensors. device: The device to use. + cu_seqlens_q: Cumulative q sequence lengths for varlen + cu_seqlens_k: Cumulative k sequence lengths for varlen + seqused_q: Per-batch effective q sequence lengths + seqused_k: Per-batch effective k sequence lengths + cu_total_m_blocks: Cumulative total m blocks tensor for varlen q + cu_block_idx_offsets: Cumulative offsets into the packed mask_block_idx / + full_block_idx tensors per batch (== cumsum of M_b * N_b). compute_full_blocks: Whether to compute full blocks. If False, only partially-masked blocks are computed. use_fast_sampling: Whether to use 5-point sampling (4 corners + center). This is much faster, but only suitable for masks where this check is sufficient. Returns: - A tuple of `BlockSparseTensors` and `BlockSparseTensorsTorch`. + BlockSparseTensorsTorch """ - # Check if mask_mod is marked as suitable for 5-point fast sampling + # Check if mask_mod is marked as suitable for 5-point sampling use_fast_sampling = getattr(mask_mod, "use_fast_sampling", use_fast_sampling) num_m_blocks = (seqlen_q + tile_m - 1) // tile_m num_n_blocks = (seqlen_k + tile_n - 1) // tile_n - mask_block_cnt = torch.zeros( - (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32 - ) - mask_block_idx = torch.zeros( - (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32 - ) - full_block_cnt = ( - torch.zeros((batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32) - if compute_full_blocks - else None - ) - full_block_idx = ( - torch.zeros( + if cu_seqlens_q is not None: + assert cu_total_m_blocks is not None, "total m blocks must be provided when varlen q" + total_m_blocks = cu_total_m_blocks[-1].item() + if cu_block_idx_offsets is None and (cu_seqlens_k is not None or seqused_k is not None): + # Derive cu_block_idx_offsets from per-batch K seqlens. + cu_block_idx_offsets_list = [0] + for batch_idx in range(batch_size): + batch_seqlen_q = cu_seqlens_q[batch_idx + 1].item() - cu_seqlens_q[batch_idx].item() + if cu_seqlens_k is not None: + batch_seqlen_k = ( + cu_seqlens_k[batch_idx + 1].item() - cu_seqlens_k[batch_idx].item() + ) + else: + batch_seqlen_k = seqused_k[batch_idx].item() + num_m_blocks_batch = (batch_seqlen_q + tile_m - 1) // tile_m + num_n_blocks_batch = (batch_seqlen_k + tile_n - 1) // tile_n + cu_block_idx_offsets_list.append( + cu_block_idx_offsets_list[-1] + num_m_blocks_batch * num_n_blocks_batch + ) + cu_block_idx_offsets = torch.tensor( + cu_block_idx_offsets_list, dtype=torch.int32, device=device + ) + if cu_block_idx_offsets is not None: + total_n_blocks = cu_block_idx_offsets[-1].item() + else: + # Uniform-K varlen-Q: every batch has the same K seqlen. + total_n_blocks = total_m_blocks * num_n_blocks + + mask_block_cnt = torch.zeros((num_heads, total_m_blocks), device=device, dtype=torch.int32) + mask_block_idx = torch.zeros((num_heads, total_n_blocks), device=device, dtype=torch.int32) + full_block_cnt = ( + torch.zeros((num_heads, total_m_blocks), device=device, dtype=torch.int32) + if compute_full_blocks + else None + ) + full_block_idx = ( + torch.zeros((num_heads, total_n_blocks), device=device, dtype=torch.int32) + if compute_full_blocks + else None + ) + else: + total_m_blocks = batch_size * num_m_blocks + total_n_blocks = batch_size * num_m_blocks * num_n_blocks + + mask_block_cnt = torch.zeros( + (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32 + ) + mask_block_idx = torch.zeros( (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32 ) - if compute_full_blocks - else None - ) + full_block_cnt = ( + torch.zeros((batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32) + if compute_full_blocks + else None + ) + full_block_idx = ( + torch.zeros( + (batch_size, num_heads, num_m_blocks, num_n_blocks), + device=device, + dtype=torch.int32, + ) + if compute_full_blocks + else None + ) blocksparse_tensors_torch = BlockSparseTensorsTorch( mask_block_cnt=mask_block_cnt, mask_block_idx=mask_block_idx, full_block_cnt=full_block_cnt, full_block_idx=full_block_idx, + cu_total_m_blocks=cu_total_m_blocks, + cu_block_idx_offsets=cu_block_idx_offsets, block_size=(tile_m, tile_n), ) mask_mod_hash = hash_callable(mask_mod) - blocksparse_tensors = to_cute_block_sparse_tensors( - blocksparse_tensors_torch, enable_tvm_ffi=True - ) + if aux_tensors is not None: + aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors) + else: + aux_tensor_metadata = None compile_key = ( tile_m, tile_n, mask_mod_hash, + aux_tensor_metadata, compute_full_blocks, + cu_seqlens_q is None, + cu_seqlens_k is None, + seqused_q is None, + seqused_k is None, aux_tensors is not None, use_fast_sampling, ) if compile_key not in compute_block_sparsity.compile_cache: + ( + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + ) = [ + to_cute_tensor(t, assumed_align=4, leading_dim=0) if t is not None else None + for t in ( + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + ) + ] + blocksparse_tensors = to_cute_block_sparse_tensors( + blocksparse_tensors_torch, enable_tvm_ffi=True + ) + if aux_tensors is not None: + cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] + else: + cute_aux_tensors = None kernel = BlockSparsityKernel( mask_mod, tile_mn=(tile_m, tile_n), @@ -362,24 +502,40 @@ def compute_block_sparsity( ) compute_block_sparsity.compile_cache[compile_key] = cute.compile( - kernel, blocksparse_tensors, seqlen_q, seqlen_k, aux_tensors, options="--enable-tvm-ffi" + kernel, + blocksparse_tensors, + seqlen_q, + seqlen_k, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + cute_aux_tensors, + options="--enable-tvm-ffi", ) - compute_block_sparsity.compile_cache[compile_key]( - ( - blocksparse_tensors_torch.mask_block_cnt, - blocksparse_tensors_torch.mask_block_idx, - blocksparse_tensors_torch.full_block_cnt, - blocksparse_tensors_torch.full_block_idx, - blocksparse_tensors_torch.dq_write_order, - blocksparse_tensors_torch.dq_write_order_full, - ), - seqlen_q, - seqlen_k, - aux_tensors, - ) + if not is_fake_mode(): + compute_block_sparsity.compile_cache[compile_key]( + ( + blocksparse_tensors_torch.mask_block_cnt, + blocksparse_tensors_torch.mask_block_idx, + blocksparse_tensors_torch.full_block_cnt, + blocksparse_tensors_torch.full_block_idx, + blocksparse_tensors_torch.cu_total_m_blocks, + blocksparse_tensors_torch.cu_block_idx_offsets, + blocksparse_tensors_torch.dq_write_order, + blocksparse_tensors_torch.dq_write_order_full, + ), + seqlen_q, + seqlen_k, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + aux_tensors, + ) - return blocksparse_tensors, blocksparse_tensors_torch + return blocksparse_tensors_torch compute_block_sparsity.compile_cache = {} diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 76c856221c5..94f0c88d817 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -63,7 +63,7 @@ def __init__( self.num_threads = num_threads self.AtomLayoutMdQ = AtomLayoutMdQ self.dQ_swapAB = dQ_swapAB - self.use_2cta_instrs = use_2cta_instrs and arch // 10 == 10 and head_dim != 64 + self.use_2cta_instrs = use_2cta_instrs and arch // 10 in [10, 11] and head_dim != 64 self.cluster_size = cluster_size @staticmethod @@ -373,7 +373,7 @@ def kernel( seqlen_q = seqlen.seqlen_q seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) - if const_expr(self.arch // 10 == 10 and self.use_2cta_instrs): + if const_expr(self.arch // 10 in [10, 11] and self.use_2cta_instrs): # 2-CTA: remap dQaccum layout into TMEM view before writing sdQ num_reduce_threads = self.num_threads thr_mma_dsk = tiled_mma.get_slice(tidx) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 9184ddeb029..061ede3d983 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -1010,6 +1010,16 @@ class SharedStorage: min_blocks_per_mp=1, ) + def _generate_attention_mask_cls(self, window_size_left, window_size_right): + return partial( + AttentionMask, + self.tile_m, + self.tile_n * self.cta_group_size, + swap_AB=True, + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + @cute.kernel def kernel( self, @@ -1413,14 +1423,7 @@ def kernel( ) TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) - AttentionMaskCls = partial( - AttentionMask, - self.tile_m, - self.tile_n * self.cta_group_size, - swap_AB=True, - window_size_left=window_size_left, - window_size_right=window_size_right, - ) + AttentionMaskCls = self._generate_attention_mask_cls(window_size_left, window_size_right) # EMPTY # (15) if warp_idx == self.empty_warp_id: @@ -3432,13 +3435,10 @@ def _dq_semaphore_lock_value( seqlen, m_block: Int32, n_block: Int32, - n_block_global_max: Int32, ) -> Int32: lock_value = n_block if const_expr(self.spt): - n_block_max_for_m_block = block_info.get_n_block_max_for_m_block( - seqlen, m_block, n_block_global_max - ) + n_block_max_for_m_block = block_info.get_n_block_max_for_m_block(seqlen, m_block) lock_value = n_block_max_for_m_block - 1 - n_block if const_expr(self.use_block_sparsity): assert blocksparse_tensors is not None @@ -3528,7 +3528,6 @@ def dQacc_reduce( mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] delay_semaphore_release = not self.tile_hdim == 192 and not self.use_block_sparsity - n_block_global_max = cute.ceil_div(seqlen.seqlen_k, self.tile_n) curr_q_cnt = Int32(0) curr_q_idx = None @@ -3627,7 +3626,6 @@ def dQacc_reduce( seqlen, m_block, n_block_cta_group, - n_block_global_max, ) barrier.wait_eq( mdQ_semaphore_cur[(m_block, None)].iterator, diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 06564165451..545e40dc04e 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -105,12 +105,18 @@ def __init__( self.mask_mod = mask_mod self.output_quant_key = output_quant_key self.qk_acc_dtype = Float32 - self.vec_size: cutlass.Constexpr = getattr( + self.score_vec_size: cutlass.Constexpr = getattr( score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 ) - if self.vec_size > 2: + if self.score_vec_size > 2: raise ValueError( - f"score_mod vec_size {self.vec_size} not supported on Sm80/90/120 " + f"score_mod vec_size {self.score_vec_size} not supported on Sm80/90/120 " + "due to accumulator thread ownership pattern." + ) + self.mask_vec_size: cutlass.Constexpr = getattr(mask_mod, "__vec_size__", 1) + if self.mask_vec_size > 1: + raise ValueError( + f"mask_mod vec_size {self.mask_vec_size} not supported on Sm80/90/120 " "due to accumulator thread ownership pattern." ) self.arch = BaseDSL._get_dsl().get_arch_enum() @@ -1260,7 +1266,7 @@ def apply_score_mod( batch_idx, head_idx, softmax_scale, - self.vec_size, + self.score_vec_size, self.qk_acc_dtype, aux_tensors, fastdiv_mods, diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py index 75bf77c89b4..84cf365d46a 100644 --- a/flash_attn/cute/flash_fwd_combine.py +++ b/flash_attn/cute/flash_fwd_combine.py @@ -69,7 +69,10 @@ def can_implement( ) -> bool: """Check if the kernel can be implemented with the given parameters.""" if dtype not in [ - cutlass.Float16, cutlass.BFloat16, cutlass.Float32, cutlass.Float8E4M3FN, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float32, + cutlass.Float8E4M3FN, ]: return False if dtype_partial not in [cutlass.Float16, cutlass.BFloat16, Float32]: diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index acf147c0ddd..932504e0041 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -161,7 +161,8 @@ def __init__( assert self.split_P_arrive % 32 == 0 assert self.split_P_arrive < self.n_block_size self.arch = BaseDSL._get_dsl().get_arch_enum() - assert self.arch >= Arch.sm_100 and self.arch <= Arch.sm_110f, "Only SM 10.x and 11.x are supported" + assert self.arch.is_family_of(Arch.sm_100f) or self.arch.is_family_of(Arch.sm_110f), \ + "Only SM 10.x and 11.x are supported" self.cta_group_size = 2 if self.use_2cta_instrs else 1 # cta_tiler M includes only 1 CTA, the scheduler will take into account the cluster shape @@ -188,12 +189,17 @@ def __init__( self.score_mod = score_mod self.mask_mod = mask_mod self.output_quant_key = output_quant_key - self.vec_size: cutlass.Constexpr = getattr( + self.score_vec_size: cutlass.Constexpr = getattr( score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 ) + self.mask_vec_size: cutlass.Constexpr = getattr(mask_mod, "__vec_size__", 1) # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) - is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f + # NOTE: is_family_of also matches any future sm_10x with x > 3 — intentional. + # The flag gates ex2 emulation; sm_103 (B300) has fast hardware ex2 and later + # Blackwell variants are assumed to inherit this, so forward-inclusion is correct + # despite the literal `is_sm103` name. + is_sm103 = self.arch.is_family_of(Arch.sm_103f) self.is_sm103 = is_sm103 # enable_ex2_emu is derived: True if tuning config has freq > 0, else fallback to default logic _default_enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103 @@ -333,7 +339,10 @@ def _setup_attributes(self): smem_size_k_per_stage = self.n_block_size * self.head_dim_padded * self.k_dtype.width // 8 smem_size_v_per_stage = self.n_block_size * self.head_dim_v_padded * self.v_dtype.width // 8 smem_size_kv_per_stage = max(smem_size_k_per_stage, smem_size_v_per_stage) // self.cta_group_size - kv_stage = (224 * 1024 - smem_size_q_o) // smem_size_kv_per_stage + # Cap small head_dim from over-staging: the 224*1024 budget undercounts + # per-stage state, so at hd_padded=16 the unbounded formula picks 52 stages + # and overflows the 227 KB SMEM cap. No-op for hd_padded >= 32 (max 26). + kv_stage = min((224 * 1024 - smem_size_q_o) // smem_size_kv_per_stage, 32) if self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and kv_stage == 2: # For hdim 192,128, we can fit 3 stages if we use uneven_kv_smem kv_stage = 3 @@ -726,6 +735,10 @@ class SharedStorage: self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) if cutlass.const_expr(self.use_block_sparsity and mPageTable is not None): raise NotImplementedError("Block sparsity + paged KV not supported on SM100") + if cutlass.const_expr(self.use_block_sparsity and self.is_varlen_q): + assert const_expr(blocksparse_tensors.cu_total_m_blocks is not None), ( + "blocksparse_tensors.cu_total_m_blocks must be provided for varlen blocksparsity" + ) # Launch the kernel synchronously self.kernel( @@ -773,6 +786,18 @@ class SharedStorage: min_blocks_per_mp=1, ) + def _generate_attention_mask_cls(self, window_size_left, window_size_right): + return partial( + AttentionMask, + self.m_block_size, + self.n_block_size, + window_size_left=window_size_left, + window_size_right=window_size_right, + qhead_per_kvhead_packgqa=( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ), + ) + # GPU device kernel @cute.kernel def kernel( @@ -1055,14 +1080,15 @@ def kernel( mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, + mCuTotalMBlocks=( + blocksparse_tensors.cu_total_m_blocks if blocksparse_tensors is not None else None + ), + mCuBlockIdxOffsets=( + blocksparse_tensors.cu_block_idx_offsets if blocksparse_tensors is not None else None + ), ) - AttentionMaskCls = partial( - AttentionMask, - self.m_block_size, - self.n_block_size, - window_size_left=window_size_left, - window_size_right=window_size_right, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + AttentionMaskCls = self._generate_attention_mask_cls( + window_size_left, window_size_right ) # Cluster wait before tensor memory alloc pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) @@ -1207,6 +1233,7 @@ def kernel( num_splits, SeqlenInfoCls, mma_tile_coord_v, + blocksparse_tensors=blocksparse_tensors, tile_scheduler=tile_scheduler, ) @@ -1495,6 +1522,9 @@ def load( batch_idx, head_idx, m_block, + seqlen, + split_idx, + num_splits, kv_producer_state, load_Q, load_K, @@ -1642,8 +1672,11 @@ def mma( batch_idx, head_idx, m_block, + split_idx, + num_splits, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, self.q_subtile_factor if self.q_subtile_factor is not None else 1, + seqlen_info=seqlen, ) process_tile = block_iter_count > Int32(0) else: @@ -1938,6 +1971,7 @@ def softmax_loop( batch_idx=batch_idx, head_idx=head_idx, aux_tensors=aux_tensors, + vec_size=self.mask_vec_size, ) # Recompute fastdiv_mods if necessary @@ -2008,8 +2042,11 @@ def softmax_loop( batch_idx, head_idx, m_block, + split_idx, + num_splits, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, self.q_subtile_factor if self.q_subtile_factor is not None else 1, + seqlen_info=seqlen, ) has_work = tile_block_count > Int32(0) else: @@ -2065,6 +2102,9 @@ def softmax_loop( batch_idx, head_idx, m_block, + seqlen, + split_idx, + num_splits, softmax_step, mask_fn, mask_fn_none, @@ -2425,8 +2465,11 @@ def correction_loop( batch_idx, head_idx, m_block, + split_idx, + num_splits, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, self.q_subtile_factor if self.q_subtile_factor is not None else 1, + seqlen_info=seqlen, ) has_work = total_block_count > Int32(0) else: @@ -2841,6 +2884,7 @@ def epilogue_s2g( num_splits: int, SeqlenInfoCls: Callable, mma_tile_coord_v: Int32 = 0, + blocksparse_tensors: Optional[BlockSparseTensors] = None, tile_scheduler=None, ): epi_consumer_phase = Int32(0) @@ -2849,8 +2893,9 @@ def epilogue_s2g( m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + has_work = const_expr(self.use_block_sparsity or not self.is_split_kv) or n_block_min < n_block_max - if const_expr(not self.is_split_kv) or n_block_min < n_block_max: + if has_work: if const_expr(self.is_split_kv): mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] else: @@ -3107,7 +3152,7 @@ def apply_score_mod( batch_idx, head_idx, softmax.softmax_scale, - self.vec_size, + self.score_vec_size, self.qk_acc_dtype, aux_tensors, fastdiv_mods, diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py index 5365439ae83..793e21cfa63 100644 --- a/flash_attn/cute/flash_fwd_sm90.py +++ b/flash_attn/cute/flash_fwd_sm90.py @@ -71,7 +71,7 @@ def __init__( "Paged KV does not support irregular head dim" ) self.cluster_shape_mn = (1, 1) - assert self.arch >= Arch.sm_90 and self.arch <= Arch.sm_90a, "Only SM 9.x is supported" + assert self.arch.is_family_of(Arch.sm_90a), "Only SM 9.x is supported" def _get_smem_layout_atom(self): sQ_layout_atom = warpgroup.make_smem_layout_atom( @@ -580,6 +580,14 @@ def kernel( mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, + mCuTotalMBlocks=( + blocksparse_tensors.cu_total_m_blocks if blocksparse_tensors is not None else None + ), + mCuBlockIdxOffsets=( + blocksparse_tensors.cu_block_idx_offsets + if blocksparse_tensors is not None + else None + ), # Don't need to pass in tile_mn because we won't access offset_padded ) AttentionMaskCls = partial( @@ -914,6 +922,7 @@ def load( batch_idx, head_idx, m_block, + seqlen, kv_producer_state, tma_load_K_fn, tma_load_V_fn, @@ -1550,7 +1559,7 @@ def apply_score_mod( batch_idx, head_idx, softmax_scale, - self.vec_size, + self.score_vec_size, self.qk_acc_dtype, aux_tensors, fastdiv_mods, diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 6b3eab32632..83d8267eb2c 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -55,6 +55,7 @@ to_cute_block_sparse_tensors, normalize_block_sparse_config, normalize_block_sparse_config_bwd, + get_block_sparse_broadcast_pattern, ) def _parse_arch_str(arch_str): @@ -258,6 +259,11 @@ def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits): # If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512. if num_n_blocks <= 4: return 1 + # Avoid ZeroDivisionError when batch_size or seqlen_q is 0. The empty-Q + # early-exit in _flash_attn_fwd handles correctness for those shapes; this + # guard just keeps the heuristic safe if called in other contexts. + if total_mblocks == 0: + return 1 # NOTE: We should revisit this heuristic after persistence is supported for split KV. # Sometimes, it's ideal to over-schedule splits for better efficiency. @@ -477,7 +483,7 @@ def _flash_attn_fwd( elif lse is not None: _validate_tensor(lse, "lse", lse_shape, torch.float32, device) - if seqlen_k == 0: + if seqlen_k == 0 or total_q == 0: out.zero_() if lse is not None: lse.fill_(float("-inf")) @@ -502,7 +508,7 @@ def _flash_attn_fwd( ) requested_use_clc_scheduler = utils._get_use_clc_scheduler_default() - requested_disable_2cta = utils._get_disable_2cta_default() + requested_disable_2cta = utils._get_disable_2cta_default(is_fwd=True) current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) @@ -547,7 +553,7 @@ def _flash_attn_fwd( if cu_seqlens_k is None and seqused_k is None: min_seqlen_k = seqlen_k seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead - if arch // 10 == 10: + if arch // 10 in [10, 11]: q_stage = 2 if seqlen_q_packgqa > tile_m else 1 else: q_stage = 1 @@ -592,8 +598,8 @@ def _flash_attn_fwd( and (tile_m % qhead_per_kvhead == 0 or not pack_gqa) ) - # hd=256 2CTA forward uses dedicated kernel (SM100 only) - use_dedicated_hd256_kernel = arch // 10 == 10 and head_dim == 256 and head_dim_v == 256 + # hd=256 2CTA forward uses dedicated kernel (Blackwell family) + use_dedicated_hd256_kernel = arch // 10 in [10, 11] and head_dim == 256 and head_dim_v == 256 use_2cta_instrs = use_2cta_instrs or use_dedicated_hd256_kernel if softcap is not None: @@ -621,23 +627,14 @@ def _flash_attn_fwd( is_dense_noncausal = not is_varlen and not causal and not local use_clc_scheduler = requested_use_clc_scheduler and not is_varlen_mha and not is_dense_noncausal - if mask_mod is not None: - if is_varlen: - raise NotImplementedError( - "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR." - ) - if use_block_sparsity: - if is_varlen: - raise NotImplementedError( - "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR." - ) # NB: pack_gqa requires block sparse head dim == 1 (broadcasted) - if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1: + head_dim_idx = 0 if block_sparse_tensors.mask_block_cnt.ndim == 2 else 1 + if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[head_dim_idx] != 1: pack_gqa = False - if is_split_kv: - raise NotImplementedError( - "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split." + if cu_seqlens_q is not None: + assert block_sparse_tensors.cu_total_m_blocks is not None, ( + "Varlen block sparsity requires block_sparse_tensors.cu_total_m_blocks." ) # See get_broadcast_dims for why this is needed in compile key @@ -645,8 +642,6 @@ def _flash_attn_fwd( normalized_block_sparse_tensors = None q_subtile_factor = None if block_sparse_tensors is not None: - if seqlen_q is None: - raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).") ( normalized_block_sparse_tensors, block_sparse_broadcast_pattern, @@ -726,6 +721,8 @@ def _flash_attn_fwd( q_descale is not None, k_descale is not None, v_descale is not None, + block_sparse_tensors is None or block_sparse_tensors.cu_total_m_blocks is None, + block_sparse_tensors is None or block_sparse_tensors.cu_block_idx_offsets is None, tile_m, tile_n, q_stage, @@ -1029,18 +1026,22 @@ def _flash_attn_fwd( window_size_left, window_size_right, learnable_sink_tensor, - sparse_tensors, - cute_aux_tensors, - current_stream, ] if arch // 10 in [10, 11]: - compile_args.insert(-3, descale_tensors_tensor) + compile_args.append(descale_tensors_tensor) + compile_args.extend([ + sparse_tensors, + cute_aux_tensors, + ]) # TODO: thread output_scale into the hd256 (BlackwellFusedMultiHeadAttentionForward) # and MLA (FlashAttentionMLAForwardSm100) kernels so fused FP8 output works there # too, then drop this special-casing and the qv/hd256 fp8-output guards above. if not use_dedicated_hd256_kernel: - compile_args.insert(-1, output_scale_tensor) - _flash_attn_fwd.compile_cache[compile_key] = cute.compile(*compile_args, options="--enable-tvm-ffi") + compile_args.append(output_scale_tensor) + compile_args.append(current_stream) + _flash_attn_fwd.compile_cache[compile_key] = cute.compile( + *compile_args, options="--enable-tvm-ffi" + ) if not is_fake_mode(): q_call, k_call, v_call = q.detach(), k.detach(), v.detach() @@ -1103,6 +1104,8 @@ def _flash_attn_fwd( normalized_block_sparse_tensors.mask_block_idx, normalized_block_sparse_tensors.full_block_cnt, normalized_block_sparse_tensors.full_block_idx, + normalized_block_sparse_tensors.cu_total_m_blocks, + normalized_block_sparse_tensors.cu_block_idx_offsets, normalized_block_sparse_tensors.dq_write_order, normalized_block_sparse_tensors.dq_write_order_full, ) @@ -1397,7 +1400,7 @@ def _flash_attn_bwd( cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1 use_2cta_instrs = cluster_size==2 - use_dedicated_hd256_kernel = arch // 10 == 10 and head_dim == 256 and head_dim_v == 256 + use_dedicated_hd256_kernel = arch // 10 in [10, 11] and head_dim == 256 and head_dim_v == 256 use_2cta_instrs = use_2cta_instrs or use_dedicated_hd256_kernel q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ @@ -1487,7 +1490,7 @@ def _flash_attn_bwd( ) score_mod = utils.create_softcap_scoremod(softcap) score_mod_bwd = utils.create_softcap_scoremod_bwd(softcap) - elif score_mod is not None: + if score_mod is not None: assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided" assert cu_seqlens_q is None and cu_seqlens_k is None, ( "varlen + score_mod not supported in bwd yet" @@ -1929,6 +1932,8 @@ def _flash_attn_bwd( normalized_block_sparse_tensors.mask_block_idx, normalized_block_sparse_tensors.full_block_cnt, normalized_block_sparse_tensors.full_block_idx, + normalized_block_sparse_tensors.cu_total_m_blocks, + normalized_block_sparse_tensors.cu_block_idx_offsets, normalized_block_sparse_tensors.dq_write_order, normalized_block_sparse_tensors.dq_write_order_full, ) @@ -2099,6 +2104,8 @@ def forward( deterministic: bool = False, score_mod: Optional[Callable] = None, score_mod_bwd: Optional[Callable] = None, + mask_mod: Optional[Callable] = None, + block_sparse_tensors: Optional[list] = None, aux_tensors: Optional[list] = None, return_lse: bool = False, out: Optional[torch.Tensor] = None, @@ -2126,6 +2133,8 @@ def forward( num_splits=num_splits, pack_gqa=pack_gqa, score_mod=score_mod, + mask_mod=mask_mod, + block_sparse_tensors=block_sparse_tensors, aux_tensors=aux_tensors, return_lse=return_lse, gather_kv_indices=gather_kv_indices, @@ -2267,6 +2276,8 @@ def flash_attn_varlen_func( deterministic: bool = False, score_mod: Optional[Callable] = None, score_mod_bwd: Optional[Callable] = None, + mask_mod: Optional[Callable] = None, + block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, aux_tensors: Optional[list] = None, return_lse: bool = False, out: Optional[torch.Tensor] = None, @@ -2310,6 +2321,8 @@ def flash_attn_varlen_func( deterministic, score_mod, score_mod_bwd, + mask_mod, + block_sparse_tensors, aux_tensors, return_lse, out, diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 9c171ba9865..47e1290fdd8 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -102,6 +102,27 @@ def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32: return x // (num_rep * num_wg) * num_rep + min(x % (num_rep * num_wg), num_rep) +@cute.jit +def apply_packed_mask_chunk( + X: cute.Tensor, + chunk_idx: cutlass.Constexpr[int], + mask: Uint32, +) -> None: + """Apply one 32-bit keep mask to one 32-column chunk. + + The one-iteration chunk loop keeps the same lowering pattern as mask_r2p_lambda. + """ + ncol = const_expr(cute.size(X.shape)) + col_base = chunk_idx * MASK_R2P_CHUNK_SIZE + for s in cutlass.range_constexpr(1): + for i in cutlass.range_constexpr( + min(MASK_R2P_CHUNK_SIZE, ncol - col_base - s * MASK_R2P_CHUNK_SIZE) + ): + in_bound = cutlass.Boolean(mask & (Uint32(1) << i)) + c = col_base + s * MASK_R2P_CHUNK_SIZE + i + X[c] = X[c] if in_bound else -Float32.inf + + @dataclass(frozen=True) class AttentionMask: tile_m: cutlass.Constexpr[int] @@ -369,6 +390,192 @@ def mask_gen_fn(s: int) -> Uint32: else acc_S_mn[r, c] ) + @cute.jit + def apply_mask_mod_sm100_scalar( + self, + acc_S: cute.Tensor, + tScS_t2r: cute.Tensor, + m_block: Int32, + n_block: Int32, + mask_seqlen: cutlass.Constexpr[bool], + mask_mod: cutlass.Constexpr[Callable], + batch_idx: Int32, + head_idx: Int32, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + head_divmod=None, + check_q_boundary: bool = False, + ) -> None: + """Apply a scalar FlexAttention mask_mod to an SM100 accumulator fragment. + + Each accumulator lane calls mask_mod once with logical (batch, head, q, kv) + indices. Pack-GQA rows are converted back to logical q/head indices before + the call. When aux tensors are present, indices are wrapped with fastdiv so + mask_mod never reads outside the per-example auxiliary storage. + """ + has_fastdiv = const_expr( + fastdiv_mods is not None and fastdiv_mods[0] is not None and fastdiv_mods[1] is not None + ) + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) + ncol = const_expr(cute.size(tScS_t2r.shape)) + + for i in cutlass.range_constexpr(ncol): + row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1] + col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0] + global_row = row_coord + m_block * self.tile_m + global_col = col_coord + n_block * self.tile_n + + if const_expr(self.qhead_per_kvhead_packgqa != 1): + assert head_divmod is not None + mask_row, head_offset = divmod(global_row, head_divmod) + head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset + else: + head_idx_for_mod = head_idx + mask_row = global_row + + mask_row_for_mod = mask_row + if const_expr(has_fastdiv and aux_tensors is not None): + if check_q_boundary: + _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0]) + global_col_for_mod = global_col + if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None): + _, global_col_for_mod = divmod(global_col, fastdiv_mods[1]) + + head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32) + mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32) + kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32) + mask_value = mask_mod( + batch_idx_ssa, + head_idx_ssa, + mask_row_ssa, + kv_idx_ssa, + self.seqlen_info, + aux_tensors, + ) + cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) + acc_S[i] = acc_S[i] if cond else -Float32.inf + if const_expr(mask_seqlen): + acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i] + if check_q_boundary: + acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i] + + @cute.jit + def apply_mask_mod_sm100_vector( + self, + acc_S: cute.Tensor, + tScS_t2r: cute.Tensor, + m_block: Int32, + n_block: Int32, + mask_seqlen: cutlass.Constexpr[bool], + mask_mod: cutlass.Constexpr[Callable], + batch_idx: Int32, + head_idx: Int32, + vec_size: cutlass.Constexpr[int], + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + head_divmod=None, + check_q_boundary: bool = False, + ) -> None: + """Apply a vectorized FlexAttention mask_mod to an SM100 fragment. + + mask_mod receives vec_size adjacent KV indices for one logical q row and + returns bit-packed Uint32 keep masks. Low bits correspond to lower KV + indices. The packed masks are combined with sequence-boundary checks, then + applied in 32-column chunks so the final masking lowers to R2P. + """ + has_fastdiv = const_expr( + fastdiv_mods is not None and fastdiv_mods[0] is not None and fastdiv_mods[1] is not None + ) + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) + ncol = const_expr(cute.size(tScS_t2r.shape)) + mask_vals_per_apply = const_expr(max(1, vec_size // 32)) + calls_per_apply = const_expr(max(1, 32 // vec_size)) + n_calls = const_expr(cute.ceil_div(ncol, vec_size)) + mask_vals = cute.make_rmem_tensor(mask_vals_per_apply, dtype=cutlass.Uint32) + + # Accumulate enough vector mask_mod calls to produce 32-bit chunks that + # apply_packed_mask_chunk can lower to R2P. + for s in cutlass.range_constexpr(n_calls): + if const_expr(s % calls_per_apply == 0): + for c in cutlass.range_constexpr(mask_vals_per_apply): + mask_vals[c] = cutlass.Uint32(0) + i = s * vec_size + row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1] + col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0] + global_row = row_coord + m_block * self.tile_m + global_col = col_coord + n_block * self.tile_n + if const_expr(self.qhead_per_kvhead_packgqa != 1): + assert head_divmod is not None + mask_row, head_offset = divmod(global_row, head_divmod) + head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset + else: + head_idx_for_mod = head_idx + mask_row = global_row + mask_row_for_mod = mask_row + if const_expr(has_fastdiv and aux_tensors is not None): + if check_q_boundary: + _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0]) + + head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32).broadcast_to( + (vec_size,) + ) + mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32).broadcast_to( + (vec_size,) + ) + batch_idx_ssa_call = batch_idx_ssa.broadcast_to((vec_size,)) + kv_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32) + + # Build the per-lane KV indices for this vectorized mask_mod call. + for j in cutlass.range_constexpr(min(vec_size, ncol - i)): + col_j_coord = tScS_t2r[i + j][1] if not self.swap_AB else tScS_t2r[i + j][0] + col_j_global = col_j_coord + n_block * self.tile_n + col_j_for_mod = col_j_global + if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None): + _, col_j_for_mod = divmod(col_j_global, fastdiv_mods[1]) + kv_idx_vec[j] = col_j_for_mod + kv_idx_ssa = kv_idx_vec.load() + + # mask_value is already bit-packed by the vectorized mask_mod. + mask_value = mask_mod( + batch_idx_ssa_call, + head_idx_ssa, + mask_row_ssa, + kv_idx_ssa, + self.seqlen_info, + aux_tensors, + ) + + # For vec_size < 32, multiple mask_mod calls fill one R2P chunk. + bit_offset = const_expr((s % calls_per_apply) * vec_size) + seqlen_thresh_call = ( + self.seqlen_k - global_col if const_expr(mask_seqlen) else cutlass.Int32(0) + ) + q_in_bounds = mask_row < self.seqlen_q if check_q_boundary else cutlass.Boolean(True) + for c in cutlass.range_constexpr(mask_vals_per_apply): + mask_val = mask_value[c] + if const_expr(vec_size < 32): + lane_keep = utils.shr_u32( + cutlass.Uint32(0xFFFFFFFF), + cutlass.Uint32(32 - vec_size), + ) + mask_val = mask_val & lane_keep + if const_expr(mask_seqlen): + mask_val = mask_val & r2p_bitmask_below(seqlen_thresh_call, c) + if check_q_boundary: + mask_val = mask_val if q_in_bounds else cutlass.Uint32(0) + mask_vals[c] = mask_vals[c] | (mask_val << bit_offset) + + # Apply only when the 32-bit chunk is complete, or at the tile tail. + is_last_in_apply = const_expr(s % calls_per_apply == calls_per_apply - 1) + is_last_overall = const_expr(s == n_calls - 1) + if const_expr(is_last_in_apply or is_last_overall): + apply_idx = s // calls_per_apply + for c in cutlass.range_constexpr(mask_vals_per_apply): + chunk_idx = apply_idx * mask_vals_per_apply + c + # Skip packed chunks that start past the accumulator fragment. + if const_expr(chunk_idx * 32 < ncol): + apply_packed_mask_chunk(acc_S, chunk_idx, mask_vals[c]) + @cute.jit def apply_mask_sm100( self, @@ -386,6 +593,7 @@ def apply_mask_sm100( aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), head_divmod=None, + vec_size: cutlass.Constexpr[int] = 1, check_q_boundary: bool = False, r2p: bool = True, rBitmask: Optional[cute.Tensor] = None, @@ -429,54 +637,43 @@ def apply_mask_sm100( ) elif const_expr(not mask_causal and not mask_local and mask_mod is not None): - # Block sparse case w/ mask_mod - has_fastdiv = const_expr( - fastdiv_mods is not None - and fastdiv_mods[0] is not None - and fastdiv_mods[1] is not None + # FlexAttention mask_mod vectorization is gated on `mask_mod.__vec_size__`. + # vec_size == 1 returns a scalar Boolean. vec_size > 1 returns packed + # Uint32 mask fragments: one word per 32 evaluated columns. + assert vec_size % 32 == 0 or 32 % vec_size == 0, ( + "vec_size must divide 32 or be a multiple of 32" ) - batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) - - ncol = const_expr(cute.size(tScS_t2r.shape)) - for i in cutlass.range_constexpr(ncol): - row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1] - col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0] - global_row = row_coord + m_block * self.tile_m - global_col = col_coord + n_block * self.tile_n - - if const_expr(self.qhead_per_kvhead_packgqa != 1): - assert head_divmod is not None - mask_row, head_offset = divmod(global_row, head_divmod) - head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset - else: - head_idx_for_mod = head_idx - mask_row = global_row - - mask_row_for_mod = mask_row - if const_expr(has_fastdiv and aux_tensors is not None): - if check_q_boundary: - _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0]) - global_col_for_mod = global_col - if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None): - _, global_col_for_mod = divmod(global_col, fastdiv_mods[1]) - - head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32) - mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32) - kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32) - mask_value = mask_mod( - batch_idx_ssa, - head_idx_ssa, - mask_row_ssa, - kv_idx_ssa, - self.seqlen_info, + if const_expr(vec_size == 1): + self.apply_mask_mod_sm100_scalar( + acc_S, + tScS_t2r, + m_block, + n_block, + mask_seqlen, + mask_mod, + batch_idx, + head_idx, aux_tensors, + fastdiv_mods, + head_divmod, + check_q_boundary, + ) + else: + self.apply_mask_mod_sm100_vector( + acc_S, + tScS_t2r, + m_block, + n_block, + mask_seqlen, + mask_mod, + batch_idx, + head_idx, + vec_size, + aux_tensors, + fastdiv_mods, + head_divmod, + check_q_boundary, ) - cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) - acc_S[i] = acc_S[i] if cond else -Float32.inf - if const_expr(mask_seqlen): - acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i] - if check_q_boundary: - acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i] else: # Causal or local causal_row_offset = self.seqlen_k - n_block * self.tile_n - self.seqlen_q @@ -871,6 +1068,9 @@ def get_trip_start_count_via_block_info( padded_offset_k=Int32(0), seqlen_q=seqlen_q, seqlen_k=seqlen_k, + m_block_offset=Int32(0), + block_idx_offset=Int32(0), + num_n_blocks=cute.ceil_div(seqlen_k, tile_shape[1]), has_cu_seqlens_q=False, has_cu_seqlens_k=False, has_seqused_q=False, @@ -911,6 +1111,9 @@ def get_trip_mask_bounds_via_block_info( padded_offset_k=Int32(0), seqlen_q=seqlen_q, seqlen_k=seqlen_k, + m_block_offset=Int32(0), + block_idx_offset=Int32(0), + num_n_blocks=cute.ceil_div(seqlen_k, tile_shape[1]), has_cu_seqlens_q=False, has_cu_seqlens_k=False, has_seqused_q=False, diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py index 8e8fdf69ddc..c8ba5672664 100644 --- a/flash_attn/cute/seqlen_info.py +++ b/flash_attn/cute/seqlen_info.py @@ -71,6 +71,9 @@ class SeqlenInfoQK: padded_offset_k: Int32 seqlen_q: Int32 seqlen_k: Int32 + m_block_offset: Int32 + block_idx_offset: Int32 + num_n_blocks: Int32 has_cu_seqlens_q: cutlass.Constexpr[bool] has_cu_seqlens_k: cutlass.Constexpr[bool] has_seqused_q: cutlass.Constexpr[bool] @@ -85,6 +88,8 @@ def create( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, + mCuBlockIdxOffsets: Optional[cute.Tensor] = None, tile_m: cutlass.Constexpr[Int32] = 128, tile_n: cutlass.Constexpr[Int32] = 128, ): @@ -116,6 +121,13 @@ def create( if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx + 1] - offset_k ) + m_block_offset = 0 if const_expr(mCuTotalMBlocks is None) else mCuTotalMBlocks[batch_idx] + num_n_blocks = (seqlen_k + tile_n - 1) // tile_n + block_idx_offset = ( + mCuBlockIdxOffsets[batch_idx] + if const_expr(mCuBlockIdxOffsets is not None) + else m_block_offset * num_n_blocks + ) return SeqlenInfoQK( offset_q, offset_k, @@ -123,6 +135,9 @@ def create( padded_offset_k, seqlen_q, seqlen_k, + m_block_offset, + block_idx_offset, + num_n_blocks, has_cu_seqlens_q=mCuSeqlensQ is not None, has_cu_seqlens_k=mCuSeqlensK is not None, has_seqused_q=mSeqUsedQ is not None, diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py index c07e3e94176..ecda0e273ad 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py @@ -21,6 +21,7 @@ from flash_attn.cute.sm100_hd256_2cta_fmha_backward_dkdvkernel import ( BlackwellFusedMultiHeadAttentionBackwardDKDVKernel, ) +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned def _as_bshkrd_tensor( @@ -251,6 +252,8 @@ def __call__( else: b = Q.shape[0] + Q, K, V, dQ, dK, dV, dO = [assume_tensor_aligned(t) for t in (Q, K, V, dQ, dK, dV, dO)] + Q = _as_bshkrd_tensor(Q, h_k, h_r, varlen) K = _as_bshkrd_tensor(K, h_k, 1, varlen) V = _as_bshkrd_tensor(V, h_k, 1, varlen) diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py index 7a8cdeede6a..885ae336f5f 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py @@ -32,6 +32,7 @@ Sm100FmhaStaticTileSchedulerParams as FmhaStaticTileSchedulerParams, ) +import flash_attn.cute.copy_utils as fa_copy_utils LAYOUT_RANK_CONSTANT = 3 @@ -244,12 +245,8 @@ def __call__( cute.assume(Q.stride[1], divby=64), Q.stride[4], ( - (Q.shape[4], Q.shape[4] * Q.shape[3]), - ( - 0 - if varlen - else cute.assume(Q.shape[1] * Q.shape[4] * h_r * h_k, divby=64) - ), + (Q.stride[3], Q.stride[2]), + 0 if cumulative_s_q is not None else cute.assume(Q.stride[0], divby=64), ), ), ), @@ -263,8 +260,8 @@ def __call__( cute.assume(K.stride[1], divby=64), K.stride[4], ( - (0, K.shape[4]), - (0 if varlen else cute.assume(K.shape[1] * K.shape[4] * 1 * h_k, divby=64)), + (0, K.stride[2]), + 0 if cumulative_s_k is not None else cute.assume(K.stride[0], divby=64), ), ), ), @@ -278,8 +275,8 @@ def __call__( cute.assume(V.stride[1], divby=64), V.stride[4], ( - (0, V.shape[4]), - (0 if varlen else cute.assume(V.shape[1] * V.shape[4] * 1 * h_k, divby=64)), + (0, V.stride[2]), + 0 if cumulative_s_k is not None else cute.assume(V.stride[0], divby=64), ), ), ), @@ -296,10 +293,49 @@ def __call__( ), ), ) - dK = cute.make_tensor(dK.iterator, K.layout) - dV = cute.make_tensor(dV.iterator, V.layout) + dK = cute.make_tensor( + dK.iterator, + cute.make_layout( + (dK.shape[1], dK.shape[4], hb), + stride=( + cute.assume(dK.stride[1], divby=64), + dK.stride[4], + ( + (0, dK.stride[2]), + 0 if cumulative_s_k is not None else cute.assume(dK.stride[0], divby=64), + ), + ), + ), + ) + dV = cute.make_tensor( + dV.iterator, + cute.make_layout( + (dV.shape[1], dV.shape[4], hb), + stride=( + cute.assume(dV.stride[1], divby=64), + dV.stride[4], + ( + (0, dV.stride[2]), + 0 if cumulative_s_k is not None else cute.assume(dV.stride[0], divby=64), + ), + ), + ), + ) # (s, d, ((h_r, h_k), b)) - dO = cute.make_tensor(dO.iterator, Q.layout) + dO = cute.make_tensor( + dO.iterator, + cute.make_layout( + (dO.shape[1], dO.shape[4], hb), + stride=( + cute.assume(dO.stride[1], divby=64), + dO.stride[4], + ( + (dO.stride[3], dO.stride[2]), + 0 if cumulative_s_q is not None else cute.assume(dO.stride[0], divby=64), + ), + ), + ), + ) # (s, d, ((h_r, h_k), b)) -> (d, s, ((h_r, h_k), b)) dOT = cute.make_tensor( @@ -2776,13 +2812,11 @@ def epilogue_clear( dK.iterator + mdK_offset, cute.make_layout((K, self.tile_shape_dQ_K, HB), stride=dK.stride), ) - gdK = cute.local_tile( - mdK, (self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1]), (None, None, None) - ) + gdK = cute.local_tile(mdK, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) gdK = gdK[None, None, blk_coord_k, 0, blk_coord_batch] cdK = cute.domain_offset( (blk_coord_k * self.tile_shape_K, 0), - cute.make_identity_tensor((self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1])), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), ) mdV_offset = cute.assume(blk_offset[1] * dV.stride[0], divby=64) @@ -2790,24 +2824,41 @@ def epilogue_clear( dV.iterator + mdV_offset, cute.make_layout((K, self.tile_shape_dV_dO, HB), stride=dV.stride), ) - gdV = cute.local_tile( - mdV, (self.PdO_mma_tiler[0], self.PdO_mma_tiler[1]), (None, None, None) - ) + gdV = cute.local_tile(mdV, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) gdV = gdV[None, None, blk_coord_k, 0, blk_coord_batch] cdV = cute.domain_offset( (blk_coord_k * self.tile_shape_K, 0), - cute.make_identity_tensor((self.PdO_mma_tiler[0], self.PdO_mma_tiler[1])), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), + ) + + num_zero_epi_threads = 256 + + tiled_copy_r2g = fa_copy_utils.tiled_copy_2d( + dK.element_type, self.cta_tiler[2], num_zero_epi_threads ) - for i in cutlass.range(tidx * 8, cute.size(gdK), block_dim_x * 8): - if cute.elem_less(cdK[i], cute.select(problem_shape, mode=[1, 2])): - gdK_i = cute.make_tensor(gdK.iterator + cute.assume(i, divby=8), (8)) - gdK_i.fill(0) + thr_copy_r2g = tiled_copy_r2g.get_slice(tidx) + + tRG_gdK = thr_copy_r2g.partition_D(gdK) + tRG_cdK = thr_copy_r2g.partition_D(cdK) + tRG_gdV = thr_copy_r2g.partition_D(gdV) + tRG_cdV = thr_copy_r2g.partition_D(cdV) + + zero_frg = cute.make_rmem_tensor_like(tRG_gdK[None, 0, None]) + zero_frg.fill(dK.element_type(0.0)) + + # check we don't need zero fragment duplication + V_frg_size = cute.size(tRG_gdV[None, 0, None]) + assert cute.size(zero_frg) == V_frg_size + + if tidx < num_zero_epi_threads: + for n in cutlass.range(cute.size(tRG_gdK.shape[1]), unroll_full=True): + if cute.elem_less(tRG_cdK[0, n, 0][0], problem_shape[1]): + cute.copy(tiled_copy_r2g, zero_frg, tRG_gdK[None, n, None]) - for i in cutlass.range(tidx * 8, cute.size(gdV), block_dim_x * 8): - if cute.elem_less(cdV[i], cute.select(problem_shape, mode=[1, 2])): - gdV_i = cute.make_tensor(gdV.iterator + cute.assume(i, divby=8), (8)) - gdV_i.fill(0) + for n in cutlass.range(cute.size(tRG_gdV.shape[1]), unroll_full=True): + if cute.elem_less(tRG_cdV[0, n, 0][0], problem_shape[1]): + cute.copy(tiled_copy_r2g, zero_frg, tRG_gdV[None, n, None]) @cute.jit def epilogue( diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py index b25ca48f007..25d6a91de70 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py @@ -29,6 +29,7 @@ Sm100FusedMask as FusedMask, ) from flash_attn.cute.tile_scheduler import SM100_TMEM_CAPACITY_COLUMNS +import flash_attn.cute.copy_utils as fa_copy_utils class BlackwellFusedMultiHeadAttentionBackwardDQKernel: @@ -187,7 +188,6 @@ def __call__( s_q64 = Int64(s_q) s_k64 = Int64(s_k) s_lse64 = Int64(s_lse) - d64 = cute.assume(Int64(d), divby=128) h_r64 = Int64(h_r) h_k64 = Int64(h_k) b64 = Int64(b) @@ -196,39 +196,72 @@ def __call__( # `cuseqlen_*` offsets stays within the tensor domain. s_q_total = q_tensor.shape[1] if cum_seqlen_q is not None else s_q64 s_k_total = k_tensor.shape[1] if cum_seqlen_k is not None else s_k64 - stride_b_qo = h_r64 * h_k64 * s_q64 * d64 if cum_seqlen_q is None else 0 - stride_b_kv = h_k64 * s_k64 * d64 if cum_seqlen_k is None else 0 b_lse = b64 if cum_seqlen_q is None else 1 stride_b_lse = h_r64 * h_k64 * s_lse64 if cum_seqlen_q is None else 0 # (s, d, ((h_r, h_k), b)) q_layout = cute.make_layout( (s_q_total, d, ((h_r, h_k), b)), - stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + stride=( + cute.assume(q_tensor.stride[1], divby=64), + q_tensor.stride[4], + ( + (q_tensor.stride[3], q_tensor.stride[2]), + 0 if cum_seqlen_q is not None else cute.assume(q_tensor.stride[0], divby=64), + ), + ), ) q = cute.make_tensor(q_tensor.iterator, q_layout) # (s, d, ((h_r, h_k), b)) do_layout = cute.make_layout( (s_q_total, d, ((h_r, h_k), b)), - stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + stride=( + cute.assume(do_tensor.stride[1], divby=64), + do_tensor.stride[4], + ( + (do_tensor.stride[3], do_tensor.stride[2]), + 0 if cum_seqlen_q is not None else cute.assume(do_tensor.stride[0], divby=64), + ), + ), ) do = cute.make_tensor(do_tensor.iterator, do_layout) # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast k_layout = cute.make_layout( (s_k_total, d, ((h_r, h_k), b)), - stride=(d64 * h_k64, 1, ((0, d64), stride_b_kv)), + stride=( + cute.assume(k_tensor.stride[1], divby=64), + k_tensor.stride[4], + ( + (0, k_tensor.stride[2]), + 0 if cum_seqlen_k is not None else cute.assume(k_tensor.stride[0], divby=64), + ), + ), ) k = cute.make_tensor(k_tensor.iterator, k_layout) # (d, s, ((h_r, h_k), b)), 0-stride for h_r to broadcast kt_layout = cute.make_layout( (d, s_k_total, ((h_r, h_k), b)), - stride=(1, d64 * h_k64, ((0, d64), stride_b_kv)), + stride=( + k_tensor.stride[4], + cute.assume(k_tensor.stride[1], divby=64), + ( + (0, k_tensor.stride[2]), + 0 if cum_seqlen_k is not None else cute.assume(k_tensor.stride[0], divby=64), + ), + ), ) kt = cute.make_tensor(k_tensor.iterator, kt_layout) # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast v_layout = cute.make_layout( (s_k_total, d, ((h_r, h_k), b)), - stride=(d64 * h_k64, 1, ((0, d64), stride_b_kv)), + stride=( + cute.assume(v_tensor.stride[1], divby=64), + v_tensor.stride[4], + ( + (0, v_tensor.stride[2]), + 0 if cum_seqlen_k is not None else cute.assume(v_tensor.stride[0], divby=64), + ), + ), ) v = cute.make_tensor(v_tensor.iterator, v_layout) # (s, ((h_r, h_k), b)) @@ -242,7 +275,14 @@ def __call__( # (s, d, ((h_r, h_k), b)) dq_layout = cute.make_layout( (s_q_total, d, ((h_r, h_k), b)), - stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + stride=( + cute.assume(dq_tensor.stride[1], divby=64), + dq_tensor.stride[4], + ( + (dq_tensor.stride[3], dq_tensor.stride[2]), + 0 if cum_seqlen_q is not None else cute.assume(dq_tensor.stride[0], divby=64), + ), + ), ) dq = cute.make_tensor(dq_tensor.iterator, dq_layout) @@ -885,36 +925,45 @@ def kernel( curr_block_coord[1], curr_block_coord[2], ) - continue_cond = False batch_coord = curr_block_coord[2][1] seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] cuseqlen_q = Int32(0) cuseqlen_k = Int32(0) - block_offset = ( - Int32(0), - Int32(0), - Int32(0), - ((Int32(0), Int32(0)), Int32(0)), - ) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k + + if has_work: block_offset = ( cuseqlen_q, cuseqlen_k, Int32(0), ((Int32(0), Int32(0)), Int32(0)), ) - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( - self.qk_mma_tiler[0], - mma_block_coord[0], - seqlen_q, - ) - if not continue_cond: mQ_qdl_ = cute.domain_offset(cute.select(block_offset, mode=[0, 2, 3]), mQ_qdl) mK_kdl_ = cute.domain_offset(cute.select(block_offset, mode=[1, 2, 3]), mK_kdl) mdO_qdl_ = cute.domain_offset( @@ -1018,18 +1067,6 @@ def kernel( # ((atom_v, rest_v), RestN, RestK) tKTgKT = tKgK_dkl[None, None, None, mma_block_coord[2]] - seqlen_kv_loop_start, seqlen_kv_loop_steps = ( - FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) - ) # LSE lse_handle = load_lse_producer.acquire_and_advance() # 32 threads loading 128 values of 32b each @@ -1158,6 +1195,9 @@ def kernel( if warp_idx == self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx mma_block_coord = ( @@ -1165,41 +1205,37 @@ def kernel( curr_block_coord[1], curr_block_coord[2], ) - continue_cond = False seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] batch_coord = curr_block_coord[2][1] + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) - - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k - - seqlen_kv_loop_start, seqlen_kv_loop_steps = ( - FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - is_leader_cta = cta_rank_in_cluster % 2 == 0 + if has_work: # dq_handle = mma_dq_producer.acquire_and_advance() load_q_releaser = load_q_consumer.clone() load_do_releaser = load_do_consumer.clone() @@ -1797,33 +1833,35 @@ def kernel( curr_block_coord[2], ) batch_coord = curr_block_coord[2][1] - continue_cond = False seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] cuseqlen_q = Int32(0) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + is_valid_k = trip_count > 0 + has_work = is_valid_q and is_valid_k - start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) + if has_work: end_count = start_count + trip_count if cutlass.const_expr(self.use_semantic_trip_range): n_block_min_causal_local_mask, n_block_min_before_local_mask = ( @@ -1893,6 +1931,7 @@ def kernel( ) lse_handle.release() sum_odo_handle.release() + work_tile = tile_sched.advance_to_next_work() ds_mma_producer.tail() @@ -1913,61 +1952,75 @@ def kernel( # cute.printf("batch_coord={}", batch_coord) seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] - continue_cond = False cuseqlen_q = Int32(0) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k + + mdQ_qdl_eff = mdQ_qdl + if cutlass.const_expr(cum_seqlen_q is not None): + block_offset_dQ = (cuseqlen_q,) + (None,) * 2 + mdQ_qdl_eff = cute.domain_offset(block_offset_dQ, mdQ_qdl) - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + # (bM, bN, loopM, loopN, loopL) + gdQ_qdl = cute.flat_divide( + mdQ_qdl_eff, cute.select(self.dsk_block_tiler, mode=[0, 1]) + ) + cdQ_qdl = cute.flat_divide( + cute.make_identity_tensor(mdQ_qdl_eff.shape), + cute.select(self.dsk_block_tiler, mode=[0, 1]), + ) - mdQ_qdl_eff = mdQ_qdl - if cutlass.const_expr(cum_seqlen_q is not None): - block_offset_dQ = ( - cuseqlen_q, - Int32(0), - Int32(0), - ((Int32(0), Int32(0)), Int32(0)), - ) - mdQ_qdl_eff = cute.domain_offset( - cute.select(block_offset_dQ, mode=[0, 2, 3]), mdQ_qdl - ) + gdQ_staged = gdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + cdQ_staged = cdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + gdQ_tma_staged = gdQ_staged - # (bM, bN, loopM, loopN, loopL) - gdQ_qdl = cute.flat_divide( - mdQ_qdl_eff, cute.select(self.dsk_block_tiler, mode=[0, 1]) + if cutlass.const_expr(not varlen): + gdQ_tma_qdl = cute.flat_divide( + mdQ_tma, cute.select(self.dsk_block_tiler, mode=[0, 1]) ) - cdQ_qdl = cute.flat_divide( - cute.make_identity_tensor(mdQ_qdl_eff.shape), - cute.select(self.dsk_block_tiler, mode=[0, 1]), - ) - - gdQ_staged = gdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] - cdQ_staged = cdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] - gdQ_tma_staged = gdQ_staged - if cutlass.const_expr(not varlen): - gdQ_tma_qdl = cute.flat_divide( - mdQ_tma, cute.select(self.dsk_block_tiler, mode=[0, 1]) - ) - gdQ_tma_staged = gdQ_tma_qdl[ - None, None, curr_block_coord[0], None, curr_block_coord[2] - ] + gdQ_tma_staged = gdQ_tma_qdl[ + None, None, curr_block_coord[0], None, curr_block_coord[2] + ] + if has_work: # dQ TMEM to GMEM mma_dq_consumer = self.dQ_epilogue( - (seqlen_q, cuseqlen_q, mQ_qdl.shape[0], batch_coord), + seqlen_q, (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged), self.epi_tile, (tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen), ) + else: + self.dQ_epilogue_write_zero( + seqlen_q, + gdQ_staged, + cdQ_staged, + ) + work_tile = tile_sched.advance_to_next_work() # NOTE: tmem.free() moved to kernel end to enable cluster-wide sync @@ -2142,12 +2195,11 @@ def compute_step( @cute.jit def dQ_epilogue( self, - value_args: Tuple, + seqlen_q: int, dq_args: Tuple, epi_tile: cute.Tile, tma_args: Tuple, ) -> Tuple[pipeline.PipelineConsumer, pipeline.PipelineProducer]: - seqlen_q, cuseqlen_q, total_q, batch_coord = value_args (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged) = dq_args tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen = tma_args dq_handle = mma_dq_consumer.wait_and_advance() @@ -2235,3 +2287,31 @@ def dQ_epilogue( cute.autovec_copy(tSMrdQ, tTMEM_LOADgdQ_i) dq_handle.release() return mma_dq_consumer + + @cute.jit + def dQ_epilogue_write_zero( + self, + seqlen_q, + gdQ_staged, + cdQ_staged, + ): + num_epi_threads = self.threads_per_warp * len(self.epilogue_warp_ids) + tidx = cute.arch.thread_idx()[0] % num_epi_threads + + tiled_copy_r2g = fa_copy_utils.tiled_copy_2d( + self.dq_dtype, cute.size(gdQ_staged.shape[1]), num_epi_threads + ) + + thr_copy_r2g = tiled_copy_r2g.get_slice(tidx) + tdQgdQ_staged = thr_copy_r2g.partition_D(gdQ_staged) + tdQcdQ_staged = thr_copy_r2g.partition_D(cdQ_staged) + + tdQrdQ = cute.make_rmem_tensor_like(tdQgdQ_staged[None, 0, None, 0]) + tdQrdQ.fill(self.dq_dtype(0.0)) + + for iter in cutlass.range(self.iterations_dsk, unroll_full=True): + tdQgdQ = tdQgdQ_staged[None, None, None, iter] + tdQcdQ = tdQcdQ_staged[None, None, None, iter] + for m in cutlass.range(cute.size(tdQgdQ.shape[1]), unroll_full=True): + if cute.elem_less(tdQcdQ[0, m, 0][0], seqlen_q): + cute.copy(tiled_copy_r2g, tdQrdQ, tdQgdQ[None, m, None]) diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py index 28087125f47..379cebc1905 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py @@ -1030,6 +1030,9 @@ def kernel( if warp_idx == self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx mma_block_coord = ( @@ -1071,10 +1074,6 @@ def kernel( ) seqlen_kv_loop_end = seqlen_kv_loop_start + seqlen_kv_loop_steps - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - is_leader_cta = cta_rank_in_cluster % 2 == 0 load_q_releaser = load_q_consumer.clone() pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) if seqlen_kv_loop_steps > 1: @@ -1323,6 +1322,11 @@ def kernel( window_size_right, ) end_count = start_count + trip_count + # require at least one softmax iteration for zero trip_count case; + # rely on masking this iteration for correctness + if end_count <= start_count: + start_count = 0 + end_count = 1 if cutlass.const_expr(self.use_semantic_trip_range): n_block_min_causal_local_mask, n_block_min_before_local_mask = ( FusedMask.get_trip_mask_bounds_via_block_info( @@ -1349,6 +1353,7 @@ def kernel( need_apply_mask = ( step >= n_block_min_causal_local_mask or step < n_block_min_before_local_mask + or step == end_count - 1 ) else: # Residual path only needs seqlen masking on the last K tile. @@ -1797,7 +1802,8 @@ def correction_epilog( row_sum = sSum[thread_idx] cute.arch.fence_view_async_shared() sum_handle.release() - scale = scale_output / row_sum + row_sum_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum + scale = scale_output / row_sum if not row_sum_is_zero_or_nan else 0.0 o_handle = mma_o_consumer.wait_and_advance() for iter in cutlass.range(self.iterations_pv): gO = gO_staged[None, None, iter] @@ -1855,6 +1861,7 @@ def store_sum_max( sSum[thread_idx] = row_sum cute.arch.fence_view_async_shared() sum_handle.commit() + row_sum_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum if cutlass.const_expr(mLSE is not None): q_idx = current_block_coord[0] * self.cta_tiler[0] + tidx @@ -1863,7 +1870,11 @@ def store_sum_max( if cutlass.const_expr(cum_seqlen_q is not None) else current_block_coord[2] ) - lse_value = scale_softmax * row_max + cute.math.log(row_sum, fastmath=True) + lse_value = ( + scale_softmax * row_max + cute.math.log(row_sum, fastmath=True) + if not row_sum_is_zero_or_nan + else -Float32.inf + ) if cute.elem_less(q_idx, seqlen_q): global_q_idx = ( q_idx + cuseqlen_q if cutlass.const_expr(cum_seqlen_q is not None) else q_idx diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index b9dc4f5c112..3daffeeff18 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -86,8 +86,11 @@ def _get_use_clc_scheduler_default() -> bool: return _fa_clc_enabled -def _get_disable_2cta_default() -> bool: - return _fa_disable_2cta_enabled or _fa_disable_2cta_cuda12 +def _get_disable_2cta_default(is_fwd: bool = False) -> bool: + if is_fwd: + return _fa_disable_2cta_enabled or _fa_disable_2cta_cuda12 + else: + return _fa_disable_2cta_enabled def _compute_base_hash(func: Callable) -> str: @@ -141,7 +144,7 @@ def hash_callable( hasher = hashlib.sha256(base_hash.encode()) - for attr, val in zip(_MIXER_ATTRS, mixer_values): + for attr, val in zip(mixer_attrs, mixer_values): hasher.update(f"{attr}={val!r}".encode()) return hasher.hexdigest() @@ -946,3 +949,20 @@ def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: def ssa_to_scalar(val): """Could inline but nice for reflecting the above api""" return val[0] + + +@cute.jit +def get_batch_from_cu_tensor(idx: Int32, cu_tensor: cute.Tensor) -> Int32: + """Binary search to determine batch from packed index in a cumulative tensor""" + batch_size = cute.size(cu_tensor) - 1 + lo = Int32(0) + hi = batch_size + + while lo < hi: + mid = (lo + hi) // 2 + if cu_tensor[mid + 1] <= idx: + lo = mid + 1 + else: + hi = mid + + return lo diff --git a/hopper/setup.py b/hopper/setup.py index 887b6339023..bc484bc6f2a 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -714,11 +714,12 @@ def run(self): else { "bdist_wheel": CachedWheelsCommand, }, - python_requires=">=3.8", + python_requires=">=3.10", install_requires=[ "torch", "einops", "packaging", "ninja", ], + options={"bdist_wheel": {"py_limited_api": "cp310"}}, ) diff --git a/tests/cute/benchmark_block_sparsity.py b/tests/cute/benchmark_block_sparsity.py index ed6bfad2daa..c7f3229d464 100644 --- a/tests/cute/benchmark_block_sparsity.py +++ b/tests/cute/benchmark_block_sparsity.py @@ -39,7 +39,7 @@ class BenchmarkConfig: mask_name: str tile_m: int = 128 tile_n: int = 128 - use_fast_sampling: bool = False + use_fast_sampling: bool = True aux_tensors_cute: Optional[list] = None @@ -162,7 +162,7 @@ def benchmark_cute_block_sparsity( blocksparse_tensors, config.seqlen_q, config.seqlen_k, - config.aux_tensors_cute, + aux_tensors=config.aux_tensors_cute, ) def generate_tensors(): @@ -336,7 +336,7 @@ def main(): mask_name=config.mask_name, tile_m=config.tile_m, tile_n=config.tile_n, - use_fast_sampling=False, + use_fast_sampling=True, aux_tensors_cute=[doc_ids_cute], ) ) diff --git a/tests/cute/mask_mod_definitions.py b/tests/cute/mask_mod_definitions.py index 0820c6f5271..71cf0b9b7a5 100644 --- a/tests/cute/mask_mod_definitions.py +++ b/tests/cute/mask_mod_definitions.py @@ -159,7 +159,6 @@ def cute_document_mask( return m_doc == n_doc -@fast_sampling @cute.jit def cute_ima_mask( batch: cute.TensorSSA, @@ -180,7 +179,99 @@ def cute_ima_mask( # n_idx_global = n_idx + seqlen_info.offset_k # ============================================================================= -# TODO: Add varlen mask implementations here + +@fast_sampling +@cute.jit +def cute_global_packed_doc_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + """Document mask using globally-indexed packed 1D doc ID tensors. + + aux_tensors[0]: doc_ids_q (total_q,) int32 — packed doc IDs for Q tokens + aux_tensors[1]: doc_ids_k (total_k,) int32 — packed doc IDs for K tokens + Mask: doc_ids_q[m_global] == doc_ids_k[n_global] + """ + doc_ids_q = aux_tensors[0] + doc_ids_k = aux_tensors[1] + + offset_q = seqlen_info.offset_q + m_global = m_idx + offset_q + m_frag = cute.make_fragment(1, cutlass.Int32) + m_frag.store(m_global) + m_doc_frag = cute.make_fragment(1, cutlass.Int32) + m_doc_frag[0] = doc_ids_q[m_frag[0]] + + offset_k = seqlen_info.offset_k + n_global = n_idx + offset_k + n_frag = cute.make_fragment(1, cutlass.Int32) + n_frag.store(n_global) + n_doc_frag = cute.make_fragment(1, cutlass.Int32) + n_doc_frag[0] = doc_ids_k[n_frag[0]] + + m_doc = m_doc_frag.load() + n_doc = n_doc_frag.load() + return m_doc == n_doc + + +@fast_sampling +@cute.jit +def cute_global_ima_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + """IMA-style mask using globally-indexed threshold tensor. + + aux_tensors[0]: thresholds (total_k,) int32 — per-global-kv-position threshold + Mask: n_idx >= thresholds[n_global] (local n_idx >= globally-indexed threshold) + """ + thresholds = aux_tensors[0] + + offset_k = seqlen_info.offset_k + n_global = n_idx + offset_k + n_frag = cute.make_fragment(1, cutlass.Int32) + n_frag.store(n_global) + val_frag = cute.make_fragment(1, cutlass.Int32) + val_frag[0] = thresholds[n_frag[0]] + threshold = val_frag.load() + + return n_idx >= threshold + + +@fast_sampling +@cute.jit +def cute_global_causal_window_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + """Causal window mask with per-token window sizes indexed globally. + + aux_tensors[0]: windows (total_q,) int32 — per-global-q-position window size + Mask: (n_idx <= m_idx) & (m_idx - n_idx <= windows[m_global]) + """ + windows = aux_tensors[0] + + offset_q = seqlen_info.offset_q + m_global = m_idx + offset_q + m_frag = cute.make_fragment(1, cutlass.Int32) + m_frag.store(m_global) + win_frag = cute.make_fragment(1, cutlass.Int32) + win_frag[0] = windows[m_frag[0]] + window = win_frag.load() + + return (n_idx <= m_idx) & ((m_idx - n_idx) <= window) # ============================================================================= @@ -246,6 +337,61 @@ def flex_ima_mask(b, h, q_idx, kv_idx, bias): return kv_idx >= bias[kv_idx] +# ============================================================================= +# Flex reference factories for global-index masks (per-sequence) +# Each factory(seq_idx, sq, sk) -> mask(b, h, q_idx, kv_idx) +# where q_idx/kv_idx are local (0-indexed) within the sequence. +# ============================================================================= + + +def global_packed_doc_flex_factory(doc_ids_q, doc_ids_k, cu_seqlens_q, cu_seqlens_k): + """Factory for per-sequence flex reference of cute_global_packed_doc_mask.""" + + def factory(seq_idx, sq, sk): + q_offset = cu_seqlens_q[seq_idx].item() + k_offset = cu_seqlens_k[seq_idx].item() + + def mask(b, h, q_idx, kv_idx): + q_global = q_offset + q_idx + k_global = k_offset + kv_idx + return doc_ids_q[q_global] == doc_ids_k[k_global] + + return mask + + return factory + + +def global_ima_flex_factory(thresholds, cu_seqlens_k): + """Factory for per-sequence flex reference of cute_global_ima_mask.""" + + def factory(seq_idx, sq, sk): + k_offset = cu_seqlens_k[seq_idx].item() + + def mask(b, h, q_idx, kv_idx): + k_global = k_offset + kv_idx + return kv_idx >= thresholds[k_global] + + return mask + + return factory + + +def global_causal_window_flex_factory(windows, cu_seqlens_q): + """Factory for per-sequence flex reference of cute_global_causal_window_mask.""" + + def factory(seq_idx, sq, sk): + q_offset = cu_seqlens_q[seq_idx].item() + + def mask(b, h, q_idx, kv_idx): + q_global = q_offset + q_idx + window = windows[q_global] + return (kv_idx <= q_idx) & ((q_idx - kv_idx) <= window) + + return mask + + return factory + + # ============================================================================= # Utility functions # ============================================================================= @@ -253,7 +399,9 @@ def flex_ima_mask(b, h, q_idx, kv_idx, bias): def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): """Generate synthetic document ids shared across heads.""" - doc_ids_tensor = torch.zeros(batch, nheads, seqlen_q, dtype=torch.int32, device=device) + doc_ids_tensor = torch.zeros( + batch, nheads, seqlen_q, dtype=torch.int32, device=device + ) for b in range(batch): N = seqlen_q max_segments = max(1, math.ceil(math.sqrt(max(N // 4, 1)))) @@ -271,6 +419,339 @@ def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): return doc_ids_tensor +def make_packed_doc_ids(seqlens_q, seqlens_k, device="cuda"): + """Generate packed 1D doc ID tensors for Q and K for varlen global-index testing. + + For each sequence, divides tokens into sqrt(len)-ish segments. + Returns (doc_ids_q, doc_ids_k) of shape (total_q,) and (total_k,). + """ + total_q = sum(seqlens_q) + total_k = sum(seqlens_k) + doc_ids_q = torch.zeros(total_q, dtype=torch.int32, device=device) + doc_ids_k = torch.zeros(total_k, dtype=torch.int32, device=device) + + q_off = 0 + k_off = 0 + for sq, sk in zip(seqlens_q, seqlens_k): + # Q doc IDs + n_docs = max(1, math.ceil(math.sqrt(max(sq // 4, 1)))) + n_docs = min(n_docs, sq) + if n_docs > 1 and sq > 1: + cuts = sorted(random.sample(range(1, sq), min(n_docs - 1, sq - 1))) + else: + cuts = [] + lengths = [b - a for a, b in zip((0, *cuts), (*cuts, sq))] + doc_ids_q[q_off : q_off + sq] = torch.repeat_interleave( + torch.arange(len(lengths), dtype=torch.int32, device=device), + torch.tensor(lengths, dtype=torch.int32, device=device), + ) + + # K doc IDs (same n_docs range for potential overlap) + if n_docs > 1 and sk > 1: + cuts_k = sorted(random.sample(range(1, sk), min(n_docs - 1, sk - 1))) + else: + cuts_k = [] + lengths_k = [b - a for a, b in zip((0, *cuts_k), (*cuts_k, sk))] + doc_ids_k[k_off : k_off + sk] = torch.repeat_interleave( + torch.arange(len(lengths_k), dtype=torch.int32, device=device), + torch.tensor(lengths_k, dtype=torch.int32, device=device), + ) + + q_off += sq + k_off += sk + + return doc_ids_q, doc_ids_k + + +def make_global_thresholds(seqlens_k, device="cuda"): + """Generate per-global-kv-token thresholds for cute_global_ima_mask. + + For each K token at local index i in a sequence of length sk, + threshold = random value in [0, sk//2]. + Returns thresholds of shape (total_k,). + """ + total_k = sum(seqlens_k) + thresholds = torch.zeros(total_k, dtype=torch.int32, device=device) + k_off = 0 + for sk in seqlens_k: + for i in range(sk): + thresholds[k_off + i] = random.randint(0, max(0, sk // 2)) + k_off += sk + return thresholds + + +def make_global_windows(seqlens_q, device="cuda"): + """Generate per-global-q-token window sizes for cute_global_causal_window_mask. + + For Q token at local index i, window = random value in [0, i] (causal). + Returns windows of shape (total_q,). + """ + total_q = sum(seqlens_q) + windows = torch.zeros(total_q, dtype=torch.int32, device=device) + q_off = 0 + for sq in seqlens_q: + for i in range(sq): + windows[q_off + i] = random.randint(0, i) + q_off += sq + return windows + + +# ============================================================================= +# Vectorized mask_mod variants (return bit-packed Uint32) +# ============================================================================= +# +# Each variant receives shape-(vec_size,) Int32 SSAs and returns a shape- +# (max(1, vec_size // 32),) Uint32 TensorSSA, where bit i of element k is the +# mask for lane (k * 32 + i). Bodies assume lane i has idx[i] = idx[0] + i, so +# they compute the packed Uint32(s) in O(1) via integer/bit arithmetic on the +# chunk-base indices instead of evaluating per lane. + + +@cute.jit +def cute_causal_mask_vec( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors: None, +) -> cute.TensorSSA: + offset = seqlen_info.seqlen_k - seqlen_info.seqlen_q + threshold = m_idx[0] + offset - n_idx[0] + cutlass.Int32(1) + m = max(cutlass.Int32(32) - threshold, cutlass.Int32(0)) + result = cute.make_rmem_tensor(1, dtype=cutlass.Uint32) + result[0] = utils.shr_u32(cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32(m)) + return result.load() + + +def get_cute_causal_mask_vec(offset: int): + return cute_causal_mask_vec + + +def get_cute_sliding_window_mask_vec(window_left: int, window_right: int, offset: int): + @cute.jit + def _cute_sliding_window_mask_vec( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, + ) -> cute.TensorSSA: + runtime_offset = seqlen_info.seqlen_k - seqlen_info.seqlen_q + center = m_idx[0] + runtime_offset + lo = center - cutlass.Int32(window_left) - n_idx[0] + hi_excl = center + cutlass.Int32(window_right) - n_idx[0] + cutlass.Int32(1) + m_below = max(cutlass.Int32(32) - hi_excl, cutlass.Int32(0)) + below = utils.shr_u32(cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32(m_below)) + n_above = max(lo, cutlass.Int32(0)) + above = utils.shl_u32(cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32(n_above)) + result = cute.make_rmem_tensor(1, dtype=cutlass.Uint32) + result[0] = below & above + return result.load() + + return _cute_sliding_window_mask_vec + + +@cute.jit +def cute_block_diagonal_mask_vec( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + block_size = cutlass.Int32(128) + block_m = m_idx[0] // block_size + lo = block_m * block_size - n_idx[0] + hi = (block_m + cutlass.Int32(1)) * block_size - n_idx[0] + m_below = max(cutlass.Int32(32) - hi, cutlass.Int32(0)) + below = utils.shr_u32(cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32(m_below)) + n_above = max(lo, cutlass.Int32(0)) + above = utils.shl_u32(cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32(n_above)) + result = cute.make_rmem_tensor(1, dtype=cutlass.Uint32) + result[0] = below & above + return result.load() + + +@cute.jit +def cute_prefix_lm_mask_vec( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + prefix = cutlass.Int32(512) + hi_pref = prefix - n_idx[0] + m_below_pref = max(cutlass.Int32(32) - hi_pref, cutlass.Int32(0)) + term1_below = utils.shr_u32(cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32(m_below_pref)) + row_in_prefix_mask = ( + cutlass.Uint32(0xFFFFFFFF) if m_idx[0] < prefix else cutlass.Uint32(0) + ) + term1 = term1_below & row_in_prefix_mask + hi_causal = m_idx[0] - n_idx[0] + cutlass.Int32(1) + m_below_causal = max(cutlass.Int32(32) - hi_causal, cutlass.Int32(0)) + term2 = utils.shr_u32(cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32(m_below_causal)) + result = cute.make_rmem_tensor(1, dtype=cutlass.Uint32) + result[0] = term1 | term2 + return result.load() + + +# ============================================================================= +# Packed-bitmask aux tensor mod +# ============================================================================= +# aux[0] is a (batch, max_seqlen_q, ceil(max_seqlen_k / 32)) Uint32 tensor where +# bit k of packed[b, q, c] is the mask for (b, q, c*32 + k). + + +@cute.jit +def cute_packed_mask_aux( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + packed = aux_tensors[0] + val = packed[batch[0], m_idx[0], n_idx[0] // cutlass.Int32(32)] + shift = cutlass.Uint32(n_idx[0] % cutlass.Int32(32)) + bit_set = cutlass.Boolean(utils.shr_u32(val, shift) & cutlass.Uint32(1)) + result = cute.make_rmem_tensor(n_idx.shape, dtype=cutlass.Boolean) + for j in cutlass.range_constexpr(cute.size(n_idx.shape)): + result[j] = bit_set + return result.load() + + +def get_cute_packed_mask_aux_vec(vec_size: int): + """Vec packed-mask, specialized for `vec_size`. For vec_size > 32, requires + `aux_tensors[0].__assumed_align__ = 16` and num_words divisible by 4.""" + if vec_size <= 32: + + @cute.jit + def _mod( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, + ) -> cute.TensorSSA: + packed = aux_tensors[0] + base = n_idx[0] // cutlass.Int32(32) + val = packed[batch[0], m_idx[0], base] + shift = cutlass.Uint32(n_idx[0] % cutlass.Int32(32)) + result = cute.make_rmem_tensor(1, dtype=cutlass.Uint32) + result[0] = utils.shr_u32(val, shift) + return result.load() + else: + num_words = vec_size // 32 + + @cute.jit + def _mod( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, + ) -> cute.TensorSSA: + packed = aux_tensors[0] + b_str, m_str, _ = packed.stride + packed_aligned = cute.make_tensor( + packed.iterator, + cute.make_layout( + packed.shape, + stride=( + cute.assume(b_str, divby=4), + cute.assume(m_str, divby=4), + 1, + ), + ), + ) + packed_row = packed_aligned[batch[0], m_idx[0], None] + packed_tiled = cute.flat_divide(packed_row, (num_words,)) + base = n_idx[0] // cutlass.Int32(32) + packed_chunk = packed_tiled[None, base // cutlass.Int32(num_words)] + loaded = cute.make_rmem_tensor_like(packed_chunk) + cute.autovec_copy(packed_chunk, loaded) + result = cute.make_rmem_tensor(num_words, dtype=cutlass.Uint32) + for k in cutlass.range_constexpr(num_words): + result[k] = cutlass.Uint32(loaded[k]) + return result.load() + + return _mod + + +def make_packed_mask_aux_tensor( + batch: int, + seqlen_q: int, + seqlen_k: int, + density: float = 0.5, + device="cuda", + seed: int = 0, +): + """Random Uint32 bit-packed mask. num_words is rounded up to a multiple of 4 + so each row is 16-byte aligned (LDG.E.128 requirement at vec_size=128).""" + g = torch.Generator(device=device).manual_seed(seed) + num_words = ((seqlen_k + 31) // 32 + 3) // 4 * 4 + bools = ( + torch.rand(batch, seqlen_q, num_words * 32, device=device, generator=g) + < density + ) + bools = bools.reshape(batch, seqlen_q, num_words, 32) + powers = 1 << torch.arange(32, device=device, dtype=torch.int64) + packed = (bools.to(torch.int64) * powers).sum(-1).to(torch.uint32) + packed.__assumed_align__ = 16 + return packed + + +VEC_MASK_FACTORIES = { + "causal": ("factory", get_cute_causal_mask_vec), + "block_causal": ("factory", get_cute_causal_mask_vec), + "sliding_window": ("factory_window", get_cute_sliding_window_mask_vec), + "block_diagonal": ("static", cute_block_diagonal_mask_vec), + "prefix_lm": ("static", cute_prefix_lm_mask_vec), + "packed_aux": ("factory_vec_size", get_cute_packed_mask_aux_vec), +} + + +# Scalar mods for vec masks not in STATIC_MASKS / PARAMETERIZED_MASK_FACTORIES. +EXTRA_SCALAR_MASKS = { + "packed_aux": cute_packed_mask_aux, +} + + +def get_vec_mask( + mask_name, seqlen_q=None, seqlen_k=None, window_size=None, vec_size=None +): + """Return a vectorized cute mask callable for `mask_name`, or None if there + is no vec form. Caller sets `__vec_size__` on the returned callable. + `vec_size` is required for masks whose body specializes on it (packed_aux).""" + if mask_name not in VEC_MASK_FACTORIES: + return None + kind, obj = VEC_MASK_FACTORIES[mask_name] + if kind == "static": + return obj + if kind == "factory_vec_size": + if vec_size is None: + raise ValueError(f"{mask_name} vec mask requires vec_size") + return obj(vec_size) + offset = ( + (seqlen_k - seqlen_q) if (seqlen_q is not None and seqlen_k is not None) else 0 + ) + if kind == "factory": + return obj(offset) + if kind == "factory_window": + if window_size is None: + raise ValueError("sliding_window vec mask requires window_size") + return obj(window_size, window_size, offset) + raise ValueError(f"unknown vec mask kind: {kind}") + + # ============================================================================= # Mask registry and factory functions # ============================================================================= diff --git a/tests/cute/test_block_sparsity.py b/tests/cute/test_block_sparsity.py index 18d578d080f..d4624c8f014 100644 --- a/tests/cute/test_block_sparsity.py +++ b/tests/cute/test_block_sparsity.py @@ -24,7 +24,7 @@ def _call_compute_block_sparsity( cute_mask, _ = get_mask_pair( mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size ) - _, torch_tensors = compute_block_sparsity( + torch_tensors = compute_block_sparsity( tile_m=tile_m, tile_n=tile_n, batch_size=batch_size, @@ -481,5 +481,382 @@ def test_fast_sampling(seqlen_q, seqlen_k, tile_m, tile_n, nheads, mask_name): assert all_match, f"Mismatch: {error_msg}" +def _compare_block_sparsity_varlen( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + cu_total_m_blocks, + cu_block_idx_offsets, + seqlens_q, + seqlens_k, + nheads, + tile_m, + tile_n, + mask_name, + window_size=None, +): + """Compare varlen block sparsity against per-sequence fixed-length references.""" + batch_size = len(seqlens_q) + cu_m = cu_total_m_blocks.cpu().tolist() + cu_n = cu_block_idx_offsets.cpu().tolist() + + for b in range(batch_size): + sq, sk = seqlens_q[b], seqlens_k[b] + num_m = (sq + tile_m - 1) // tile_m + num_n = (sk + tile_n - 1) // tile_n + m_off = cu_m[b] + n_off = cu_n[b] + + _, mask_mod_flex = get_mask_pair( + mask_name, seqlen_q=sq, seqlen_k=sk, window_size=window_size + ) + block_mask = create_block_mask( + mask_mod_flex, + B=1, + H=nheads, + Q_LEN=sq, + KV_LEN=sk, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + _, _, ref_mask_cnt, ref_mask_idx, ref_full_cnt, ref_full_idx, *_ = ( + block_mask.as_tuple() + ) + + for h in range(nheads): + for m in range(num_m): + global_m = m_off + m + n_base = n_off + m * num_n + + vl_mask_cnt = mask_block_cnt[h, global_m].item() + vl_full_cnt = full_block_cnt[h, global_m].item() + vl_mask_set = set( + mask_block_idx[h, n_base : n_base + vl_mask_cnt].tolist() + ) + vl_full_set = set( + full_block_idx[h, n_base : n_base + vl_full_cnt].tolist() + ) + + r_mask_cnt = ref_mask_cnt[0, h, m].item() + r_full_cnt = ref_full_cnt[0, h, m].item() + r_mask_set = set(ref_mask_idx[0, h, m, :r_mask_cnt].tolist()) + r_full_set = set(ref_full_idx[0, h, m, :r_full_cnt].tolist()) + + last_m_block = (sq - 1) // tile_m + last_n_block = (sk - 1) // tile_n + m_is_boundary = sq % tile_m != 0 and m == last_m_block + n_is_boundary = sk % tile_n != 0 + + def is_boundary_affected( + n_block, + _m_bnd=m_is_boundary, + _n_bnd=n_is_boundary, + _ln=last_n_block, + ): + return _m_bnd or (_n_bnd and n_block == _ln) + + non_boundary_vl_full = { + n for n in vl_full_set if not is_boundary_affected(n) + } + non_boundary_ref_full = { + n for n in r_full_set if not is_boundary_affected(n) + } + if non_boundary_vl_full != non_boundary_ref_full: + return False, ( + f"Varlen full block mismatch at batch={b}, head={h}, m_block={m} " + f"(sq={sq}, sk={sk}): " + f"varlen={sorted(non_boundary_vl_full)}, ref={sorted(non_boundary_ref_full)}" + ) + + non_boundary_vl_mask = { + n for n in vl_mask_set if not is_boundary_affected(n) + } + non_boundary_ref_mask = { + n for n in r_mask_set if not is_boundary_affected(n) + } + if non_boundary_vl_mask != non_boundary_ref_mask: + return False, ( + f"Varlen partial block mismatch at batch={b}, head={h}, m_block={m} " + f"(sq={sq}, sk={sk}): " + f"varlen={sorted(non_boundary_vl_mask)}, ref={sorted(non_boundary_ref_mask)}" + ) + + return True, "" + + +# ---- Varlen test configurations ---- + +VARLEN_SEQLEN_CONFIGS = [ + # (seqlens_q, seqlens_k) - lists of per-batch lengths + # Uniform lengths (should match fixed-length behavior) + ([128, 128], [128, 128]), + ([256, 256], [256, 256]), + # Different lengths per batch + ([64, 128], [64, 128]), + ([128, 256], [128, 256]), + ([256, 512], [256, 512]), + ([64, 128, 256], [64, 128, 256]), + # Unaligned + ([113, 203], [113, 203]), + ([127, 255], [127, 255]), + ([100, 200, 300], [100, 200, 300]), + # Asymmetric Q/K + ([128, 256], [256, 128]), + ([64, 128], [128, 256]), + # Single element sequences + ([1, 128], [1, 128]), + ([64, 1], [64, 1]), + # Large spread + ([32, 512, 128], [32, 512, 128]), + ([1024, 64], [1024, 64]), +] + + +def _generate_varlen_inputs( + seqlens_q, + seqlens_k, + tile_m, + tile_n, + device="cuda", +): + """Generate cu_seqlens and cu_total_*_blocks for a varlen batch. + + Args: + seqlens_q: list of per-batch query sequence lengths + seqlens_k: list of per-batch key sequence lengths + tile_m, tile_n: tile sizes + Returns: + cu_seqlens_q, cu_seqlens_k, cu_total_m_blocks, cu_block_idx_offsets + """ + batch_size = len(seqlens_q) + assert len(seqlens_k) == batch_size + + cu_seqlens_q = [0] + cu_seqlens_k = [0] + cu_total_m_blocks = [0] + cu_block_idx_offsets = [0] + + for b in range(batch_size): + cu_seqlens_q.append(cu_seqlens_q[-1] + seqlens_q[b]) + cu_seqlens_k.append(cu_seqlens_k[-1] + seqlens_k[b]) + num_m = (seqlens_q[b] + tile_m - 1) // tile_m + num_n = (seqlens_k[b] + tile_n - 1) // tile_n + cu_total_m_blocks.append(cu_total_m_blocks[-1] + num_m) + cu_block_idx_offsets.append(cu_block_idx_offsets[-1] + num_m * num_n) + + return ( + torch.tensor(cu_seqlens_q, device=device, dtype=torch.int32), + torch.tensor(cu_seqlens_k, device=device, dtype=torch.int32), + torch.tensor(cu_total_m_blocks, device=device, dtype=torch.int32), + torch.tensor(cu_block_idx_offsets, device=device, dtype=torch.int32), + ) + + +def _call_compute_block_sparsity_varlen( + seqlens_q, + seqlens_k, + nheads, + tile_m, + tile_n, + mask_name, + window_size=None, + aux_tensors=None, + use_fast_sampling=False, +): + """Call compute_block_sparsity with varlen inputs.""" + batch_size = len(seqlens_q) + # Use max seqlens for mask_mod compilation (the kernel uses per-batch seqlens at runtime) + max_seqlen_q = max(seqlens_q) + max_seqlen_k = max(seqlens_k) + + cute_mask, _ = get_mask_pair( + mask_name, seqlen_q=max_seqlen_q, seqlen_k=max_seqlen_k, window_size=window_size + ) + + cu_seqlens_q, cu_seqlens_k, cu_total_m_blocks, cu_block_idx_offsets = ( + _generate_varlen_inputs(seqlens_q, seqlens_k, tile_m, tile_n) + ) + + torch_tensors = compute_block_sparsity( + tile_m=tile_m, + tile_n=tile_n, + batch_size=batch_size, + num_heads=nheads, + seqlen_q=max_seqlen_q, + seqlen_k=max_seqlen_k, + mask_mod=cute_mask, + aux_tensors=aux_tensors, + device="cuda", + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + cu_total_m_blocks=cu_total_m_blocks, + cu_block_idx_offsets=cu_block_idx_offsets, + use_fast_sampling=use_fast_sampling, + ) + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = torch_tensors + return ( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + cu_total_m_blocks, + cu_block_idx_offsets, + ) + + +@pytest.mark.parametrize("seqlens_q,seqlens_k", VARLEN_SEQLEN_CONFIGS) +@pytest.mark.parametrize("tile_m,tile_n", [(64, 64), (128, 128), (64, 128), (128, 64)]) +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize("mask_name", ["causal", "block_diagonal"]) +def test_varlen(seqlens_q, seqlens_k, tile_m, tile_n, nheads, mask_name): + """Test variable-length sequence support.""" + ( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + cu_total_m_blocks, + cu_block_idx_offsets, + ) = _call_compute_block_sparsity_varlen( + seqlens_q, seqlens_k, nheads, tile_m, tile_n, mask_name + ) + + all_match, error_msg = _compare_block_sparsity_varlen( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + cu_total_m_blocks, + cu_block_idx_offsets, + seqlens_q, + seqlens_k, + nheads, + tile_m, + tile_n, + mask_name, + ) + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize( + "seqlens_q,seqlens_k", + [ + ([128, 128], [128, 128]), + ([64, 128, 256], [64, 128, 256]), + ([100, 200], [100, 200]), + ], +) +@pytest.mark.parametrize("tile_m,tile_n", [(64, 64), (128, 128)]) +@pytest.mark.parametrize("nheads", [1]) +@pytest.mark.parametrize( + "mask_name,window_size", + [("causal", None), ("sliding_window", 64), ("sliding_window", 256)], +) +def test_varlen_parameterized_masks( + seqlens_q, seqlens_k, tile_m, tile_n, nheads, mask_name, window_size +): + """Test varlen with parameterized masks.""" + # Skip sliding window when any seqlen_q > seqlen_k + if mask_name == "sliding_window" and any( + sq > sk for sq, sk in zip(seqlens_q, seqlens_k) + ): + pytest.skip("Sliding window not supported for seqlen_q > seqlen_k") + + ( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + cu_total_m_blocks, + cu_block_idx_offsets, + ) = _call_compute_block_sparsity_varlen( + seqlens_q, seqlens_k, nheads, tile_m, tile_n, mask_name, window_size=window_size + ) + + all_match, error_msg = _compare_block_sparsity_varlen( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + cu_total_m_blocks, + cu_block_idx_offsets, + seqlens_q, + seqlens_k, + nheads, + tile_m, + tile_n, + mask_name, + window_size=window_size, + ) + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize("tile_m,tile_n", [(64, 64), (128, 128)]) +def test_varlen_matches_fixed_length(nheads, tile_m, tile_n): + """Verify that varlen with uniform sequence lengths produces identical + results to the fixed-length path.""" + seqlen_q, seqlen_k = 256, 256 + batch_size = 3 + mask_name = "causal" + + # Fixed-length result + fixed_mask_cnt, fixed_mask_idx, fixed_full_cnt, fixed_full_idx = ( + _call_compute_block_sparsity( + batch_size, nheads, seqlen_q, seqlen_k, tile_m, tile_n, mask_name + ) + ) + + # Varlen with uniform lengths + seqlens_q = [seqlen_q] * batch_size + seqlens_k = [seqlen_k] * batch_size + ( + vl_mask_cnt, + vl_mask_idx, + vl_full_cnt, + vl_full_idx, + cu_total_m_blocks, + cu_block_idx_offsets, + ) = _call_compute_block_sparsity_varlen( + seqlens_q, seqlens_k, nheads, tile_m, tile_n, mask_name + ) + + num_m = (seqlen_q + tile_m - 1) // tile_m + num_n = (seqlen_k + tile_n - 1) // tile_n + cu_m = cu_total_m_blocks.cpu().tolist() + cu_n = cu_block_idx_offsets.cpu().tolist() + + for b in range(batch_size): + for h in range(nheads): + for m in range(num_m): + global_m = cu_m[b] + m + n_base = cu_n[b] + m * num_n + + # Counts should match + assert ( + vl_mask_cnt[h, global_m].item() == fixed_mask_cnt[b, h, m].item() + ), f"Mask count mismatch at b={b}, h={h}, m={m}" + assert ( + vl_full_cnt[h, global_m].item() == fixed_full_cnt[b, h, m].item() + ), f"Full count mismatch at b={b}, h={h}, m={m}" + + mc = vl_mask_cnt[h, global_m].item() + fc = vl_full_cnt[h, global_m].item() + vl_mask_set = set(vl_mask_idx[h, n_base : n_base + mc].tolist()) + vl_full_set = set(vl_full_idx[h, n_base : n_base + fc].tolist()) + fixed_mask_set = set(fixed_mask_idx[b, h, m, :mc].tolist()) + fixed_full_set = set(fixed_full_idx[b, h, m, :fc].tolist()) + + assert vl_mask_set == fixed_mask_set, ( + f"Mask idx mismatch at b={b}, h={h}, m={m}: " + f"varlen={sorted(vl_mask_set)}, fixed={sorted(fixed_mask_set)}" + ) + assert vl_full_set == fixed_full_set, ( + f"Full idx mismatch at b={b}, h={h}, m={m}: " + f"varlen={sorted(vl_full_set)}, fixed={sorted(fixed_full_set)}" + ) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 57c98134bca..d980d127fb5 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -154,8 +154,6 @@ def test_flash_attn_output( pytest.skip("SM100 head_dim=256 2CTA kernel does not support softcap yet") if deterministic: pytest.skip("SM100 head_dim=256 2CTA kernel does not support deterministic mode yet") - if causal and seqlen_q > seqlen_k: - pytest.skip("SM100 head_dim=256 2CTA kernel does not support causal attention with seqlen_q > seqlen_k yet") device = "cuda" # set seed seed = 0 @@ -439,6 +437,50 @@ def test_flash_attn_output( ).abs().max().item() + dv_atol +# Regression test for #2591: SMEM overflow at small head_dims on SM100. The main +# test_flash_attn_output skips d < 64, but _validate_head_dims accepts head_dim >= 8 +# for sm_100/110, so this path needs coverage. Trigger requires +# seqlen_q_packgqa > tile_m to push q_stage 1->2. +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [8, 16, 32]) +@pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128), (2048, 2048)]) +@retry_on_oom +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_small_head_dim(seqlen_q, seqlen_k, d, causal, dtype): + device = "cuda" + seed = 0 + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.empty_cache() + torch.cuda.synchronize() + batch_size = 2 + nheads = 2 + nheads_kv = nheads + dtype_ref = dtype + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ).requires_grad_() + k_ref = torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ).requires_grad_() + v_ref = torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ).requires_grad_() + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + out_ref, _ = attention_ref(q_ref, k_ref, v_ref, None, None, causal=causal) + out_pt, _ = attention_ref( + q_ref, k_ref, v_ref, None, None, causal=causal, upcast=False, reorder_ops=True + ) + out, _ = flash_attn_func(q, k, v, causal=causal) + if is_fake_mode(): + return + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + assert (out - out_ref).abs().max().item() <= 2 * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + + # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) @@ -550,10 +592,6 @@ def test_flash_attn_varlen_output( pytest.skip("SM100 head_dim=256 2CTA kernel does not support softcap yet") if deterministic: pytest.skip("SM100 head_dim=256 2CTA kernel does not support deterministic mode yet") - if causal and seqlen_q > seqlen_k: - pytest.skip("SM100 head_dim=256 2CTA kernel does not support causal attention with seqlen_q > seqlen_k yet") - if zero_lengths_q or zero_lengths_k: - pytest.skip("SM100 head_dim=256 2CTA kernel does not support zero-length sequences yet") if not unpad_q or not unpad_kv: pytest.skip("SM100 head_dim=256 2CTA kernel does not support seqused_q/seqused_k mode yet (requires unpad_q=True and unpad_kv=True)") if ( @@ -841,6 +879,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): or (IS_SM100 and d == 256 and dv == 256) ) and not has_learnable_sink + and softcap == 0.0 # TODO: support softcap != 0.0 in varlen bwd # and False ): if d > 192 and IS_SM90: @@ -2739,3 +2778,103 @@ def test_flash_attn_seqlen_k_zero(seqlen_q, d, causal): if lse is not None: assert torch.all(torch.isinf(lse) & (lse < 0)).item(), \ f"Expected all -inf LSE when seqlen_k=0, got: {lse}" + + +# --------------------------------------------------------------------------- +# Regression test (#2503): empty Q workload must not crash with ZeroDivisionError +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "shape", + [ + # (batch_size, seqlen_q) — the product drives num_splits_heuristic's divisor. + (2, 0), # seqlen_q == 0 (common in CUDA graph padding) + (0, 64), # batch_size == 0 (empty microbatch) + ], +) +@pytest.mark.parametrize("causal", [False, True]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_empty_q_dense(shape, causal): + """Dense (non-varlen) path must not raise ZeroDivisionError when Q is empty. + + num_splits_heuristic divides num_SMs by total_mblocks = batch_size * num_head_kv + * num_m_blocks. When seqlen_q == 0 or batch_size == 0, total_mblocks collapses + to 0 and the division crashes. The existing seqlen_k == 0 early-exit does not + cover this case. Regression for + https://github.com/Dao-AILab/flash-attention/issues/2503. + """ + batch_size, seqlen_q = shape + device = "cuda" + dtype = torch.bfloat16 + d = 128 + nheads = 16 + nheads_kv = 16 + # Pick seqlen_k large enough that num_n_blocks > 4 so the heuristic's own + # `num_n_blocks <= 4` early-return does not mask the bug. + seqlen_k = 4096 + + torch.manual_seed(0) + + q = torch.empty(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) + + out, lse = flash_attn_func(q, k, v, causal=causal) + + if is_fake_mode(): + return + + assert out.shape == (batch_size, seqlen_q, nheads, d), \ + f"Unexpected output shape: {out.shape}" + # With zero elements these are vacuously true, but we still want to assert + # the function returned cleanly rather than erroring out. + assert out.numel() == 0 + if lse is not None: + assert lse.numel() == 0 + + +@pytest.mark.parametrize("causal", [False, True]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_empty_q_varlen(causal): + """Varlen path must not raise ZeroDivisionError when total_q == 0. + + Parallels test_flash_attn_empty_q_dense for the cu_seqlens_q path, where + total_q = q.shape[0] feeds into total_mblocks. Regression for + https://github.com/Dao-AILab/flash-attention/issues/2503. + """ + device = "cuda" + dtype = torch.bfloat16 + d = 128 + nheads = 16 + nheads_kv = 16 + batch_size = 2 + seqlen_k_per_batch = 2048 + + torch.manual_seed(0) + + # All zero-length Q sequences — total_q == 0 while cu_seqlens_q is well-formed. + cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_k = torch.tensor( + [0, seqlen_k_per_batch, 2 * seqlen_k_per_batch], + dtype=torch.int32, device=device, + ) + total_k = int(cu_seqlens_k[-1].item()) + + q = torch.empty(0, nheads, d, device=device, dtype=dtype) + k = torch.randn(total_k, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(total_k, nheads_kv, d, device=device, dtype=dtype) + + out, lse = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=0, max_seqlen_k=seqlen_k_per_batch, + causal=causal, + ) + + if is_fake_mode(): + return + + assert out.shape == (0, nheads, d), f"Unexpected output shape: {out.shape}" + assert out.numel() == 0 + if lse is not None: + assert lse.numel() == 0 diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index ceef6500b97..a4228dc48a0 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -32,7 +32,13 @@ ) from flash_attn.cute.cache_utils import get_jit_cache from flash_attn.cute import utils -from mask_mod_definitions import get_mask_pair, random_doc_id_tensor +from mask_mod_definitions import ( + get_mask_pair, + get_vec_mask, + random_doc_id_tensor, + EXTRA_SCALAR_MASKS, + make_packed_mask_aux_tensor, +) COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] @@ -47,6 +53,7 @@ def reset_torch_state(): torch._dynamo.reset() torch.cuda.empty_cache() + def create_tensors( batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype ): @@ -295,10 +302,11 @@ def _run_mask_test( doc_ids = random_doc_id_tensor(nheads, batch_size, doc_len, device="cuda").to( dtype=torch.int32, device="cuda" ) - original_flex_mask = mask_mod_flex + doc_ids.__leading_dim__ = 2 + doc_row = doc_ids[0, 0] # (doc_len,); batch_size=1, all heads identical - def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): - return original_flex_mask(b, h, q_idx, kv_idx, doc_ids) + def mask_mod_flex(b, h, q_idx, kv_idx): + return doc_row[q_idx] == doc_row[kv_idx] aux_tensors_arg = [doc_ids] elif mask_name == "ima": @@ -827,6 +835,122 @@ def test_parameterized_masks( ) +# ============================================================================= +# Vectorized mask_mod equality tests +# Pattern: scalar mask is reference; vec mask at multiple __vec_size__ values +# must produce bit-identical output. +# ============================================================================= + +# (mask_name, window_size, needs_aux) +VEC_MASK_TEST_CASES = [ + ("causal", None, False), + ("block_causal", None, False), + ("sliding_window", 128, False), + ("block_diagonal", None, False), + ("prefix_lm", None, False), + ("packed_aux", None, True), +] + +# Vectorized mask_mod application is currently implemented for SM100/SM110 forward. +# vec_size > 32 is only supported by packed_aux (other vec mods return shape-(1,) Uint32). +VEC_MASK_SIZES_TO_CHECK_EQUALITY = [2, 8, 32, 128] + + +def _run_mask_mod_only(q, k, v, mask_mod, aux_tensors, pack_gqa): + out = torch.empty_like(q) + _flash_attn_fwd( + q=q, + k=k, + v=v, + out=out, + lse=None, + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=1.0 / math.sqrt(q.shape[-1]), + causal=False, + softcap=None, + window_size_left=-1, + window_size_right=-1, + learnable_sink=None, + tile_mn=(128, 128), + pack_gqa=pack_gqa, + _arch=None, + score_mod=None, + mask_mod=mask_mod, + block_sparse_tensors=None, + return_lse=False, + aux_tensors=aux_tensors, + ) + return out + + +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [(128, 128), (256, 256), (113, 203), (256, 512), (1024, 1024)], +) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 4), (4, 2)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("mask_case", VEC_MASK_TEST_CASES) +def test_cute_mask_mod_vectorized( + seqlen_q, seqlen_k, qhead_per_kvhead, num_kv_heads, dtype, mask_case +): + """Tests equality between scalar and vectorized versions of mask mods.""" + if COMPUTE_CAPABILITY not in (10, 11): + pytest.skip("vectorized mask_mod application is SM100/SM110-only") + mask_name, window_size, needs_aux = mask_case + torch.manual_seed(42) + + num_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + head_dim = 128 + batch_size = 2 + + q = torch.randn( + batch_size, seqlen_q, num_heads, head_dim, device="cuda", dtype=dtype + ) + k = torch.randn( + batch_size, seqlen_k, num_kv_heads, head_dim, device="cuda", dtype=dtype + ) + v = torch.randn( + batch_size, seqlen_k, num_kv_heads, head_dim, device="cuda", dtype=dtype + ) + + if needs_aux: + aux_tensors = [make_packed_mask_aux_tensor(batch_size, seqlen_q, seqlen_k)] + scalar_mod = EXTRA_SCALAR_MASKS[mask_name] + else: + aux_tensors = None + scalar_mod, _ = get_mask_pair( + mask_name, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + window_size=window_size, + ) + + out_ref = _run_mask_mod_only(q, k, v, scalar_mod, aux_tensors, pack_gqa) + + for vec_size in VEC_MASK_SIZES_TO_CHECK_EQUALITY: + if vec_size > 32 and mask_name != "packed_aux": + continue + vec_mod = get_vec_mask( + mask_name, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + window_size=window_size, + vec_size=vec_size, + ) + if vec_mod is None: + pytest.skip(f"no vec mask for {mask_name}") + vec_mod.__vec_size__ = vec_size + out = _run_mask_mod_only(q, k, v, vec_mod, aux_tensors, pack_gqa) + assert torch.equal(out, out_ref), ( + f"{mask_name} vec_size={vec_size}: output mismatch vs scalar reference" + ) + + def test_sm100_block_sparse_sink_all_masked(): """Block-sparse regression for the sink path""" if torch.cuda.get_device_capability()[0] != 10: @@ -2257,6 +2381,146 @@ def doc_mask(b, h, q_idx, kv_idx): _run_write_order_test(doc_mask, seqlen_q, seqlen_k, block_size=128, B=B, H=H, spt=spt) +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="SM100/SM110 SplitKV block sparse forward only") +def test_block_sparse_splitkv_matches_unsplit(): + torch.manual_seed(123) + batch_size = 1 + nheads = 4 + seqlen = 2048 + headdim = 64 + tile_m = 128 + tile_n = 128 + dtype = torch.bfloat16 + sparse_tile_m = 2 * tile_m + + mask_mod_cute, mask_mod_flex = get_mask_pair("causal", seqlen_q=seqlen, seqlen_k=seqlen) + tensors = create_tensors( + batch_size, seqlen, seqlen, nheads, nheads, headdim, headdim, dtype + ) + + bm = create_block_mask( + mask_mod_flex, + batch_size, + nheads, + seqlen, + seqlen, + device="cuda", + BLOCK_SIZE=(sparse_tile_m, tile_n), + ) + (_, _, kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, *_) = bm.as_tuple() + block_sparse_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + block_size=(sparse_tile_m, tile_n), + ) + + out_unsplit, lse_unsplit = _flash_attn_fwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=tensors["out"].clone(), + lse=tensors["lse"].clone(), + softmax_scale=1.0 / math.sqrt(headdim), + causal=False, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_fwd, + num_splits=1, + return_lse=True, + ) + out_split, lse_split = _flash_attn_fwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=tensors["out"].clone(), + lse=tensors["lse"].clone(), + softmax_scale=1.0 / math.sqrt(headdim), + causal=False, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_fwd, + num_splits=3, + return_lse=True, + ) + + out_ref = compute_reference_flex_attn(tensors, mask_mod_flex, block_size=(sparse_tile_m, tile_n)) + out_ref_fp32 = compute_reference_flex_attn( + {name: tensor.float() for name, tensor in tensors.items()}, + mask_mod_flex, + block_size=(sparse_tile_m, tile_n), + ) + + assert_fwd_matches_reference(out_split, out_ref_fp32, out_ref) + assert torch.allclose(lse_split, lse_unsplit, atol=2e-3, rtol=2e-3) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="SM100/SM110 SplitKV block sparse forward only") +def test_block_sparse_splitkv_oversplit_sparse_blocks(): + torch.manual_seed(321) + batch_size = 1 + nheads = 4 + seqlen_q = 513 + seqlen_k = 1024 + headdim = 64 + tile_m = 128 + tile_n = 128 + dtype = torch.bfloat16 + sparse_tile_m = 2 * tile_m + + mask_mod_cute, mask_mod_flex = get_mask_pair("block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k) + tensors = create_tensors( + batch_size, seqlen_q, seqlen_k, nheads, nheads, headdim, headdim, dtype + ) + bm = create_block_mask( + mask_mod_flex, + batch_size, + nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(sparse_tile_m, tile_n), + ) + (_, _, kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, *_) = bm.as_tuple() + block_sparse_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + block_size=(sparse_tile_m, tile_n), + ) + + out_unsplit, _ = _flash_attn_fwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=tensors["out"].clone(), + lse=tensors["lse"].clone(), + softmax_scale=1.0 / math.sqrt(headdim), + causal=False, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_fwd, + num_splits=1, + return_lse=True, + ) + out_split, _ = _flash_attn_fwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=tensors["out"].clone(), + lse=tensors["lse"].clone(), + softmax_scale=1.0 / math.sqrt(headdim), + causal=False, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_fwd, + num_splits=8, + return_lse=True, + ) + + assert not torch.isnan(out_split).any() + assert torch.isfinite(out_split).all() + assert torch.allclose(out_split, out_unsplit, atol=4e-3, rtol=4e-3) + + def test_compact_block_sparse_indices(): """Test that compact block sparse index tensors (idx.shape[3] < n_blocks) work correctly. diff --git a/tests/cute/test_mask_mod_varlen.py b/tests/cute/test_mask_mod_varlen.py new file mode 100644 index 00000000000..6e37e9ed4b8 --- /dev/null +++ b/tests/cute/test_mask_mod_varlen.py @@ -0,0 +1,1173 @@ +# mask_mod varlen test script +# Forward-only +# +# Since flex_attention doesn't support varlen natively, we compare +# results sequence-by-sequence: run the kernel with cu_seqlens (packed), +# then run flex_attention per-sequence and compare. +# +# Usage: +# pytest test_mask_mod_varlen.py -v -s + +import math +import random + +import pytest +import torch +import torch.nn.functional as F +import cutlass +import cutlass.cute as cute +from torch.nn.attention.flex_attention import create_block_mask, flex_attention + +from flash_attn.cute.interface import _flash_attn_fwd +from flash_attn.cute import utils +from flash_attn.cute.compute_block_sparsity import compute_block_sparsity +from mask_mod_definitions import ( + get_mask_pair, + get_vec_mask, + random_doc_id_tensor, + STATIC_MASKS, + PARAMETERIZED_MASK_FACTORIES, + EXTRA_SCALAR_MASKS, + make_packed_mask_aux_tensor, + cute_global_packed_doc_mask, + cute_global_ima_mask, + cute_global_causal_window_mask, + global_packed_doc_flex_factory, + global_ima_flex_factory, + global_causal_window_flex_factory, + make_packed_doc_ids, + make_global_thresholds, + make_global_windows, +) + +COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] + + +@pytest.fixture(autouse=True) +def reset_torch_state(): + """Reset torch dynamo/compile state between tests to avoid state pollution.""" + torch._dynamo.reset() + torch.cuda.empty_cache() + yield + torch._dynamo.reset() + torch.cuda.empty_cache() + + +# ============================================================================= +# Seqlen configs for varlen (list of per-sequence lengths) +# ============================================================================= + +SEQLEN_CONFIGS = [ + # Simple cases + ([1], [1]), + ([64], [64]), + ([128], [128]), + # Multiple sequences, same length + ([128, 128], [128, 128]), + ([64, 64, 64], [64, 64, 64]), + # Multiple sequences, varying lengths + ([64, 128], [64, 128]), + ([32, 64, 128], [32, 64, 128]), + ([113, 203], [113, 203]), + ([256, 512], [256, 512]), + # Asymmetric Q/K lengths + ([64, 128], [32, 64]), + ([100, 100], [50, 50]), + # Edge cases + ([1, 1], [1, 1]), + ([1, 256], [1, 256]), + ([256, 1], [256, 1]), + ([17, 33, 65], [17, 33, 65]), + # Larger sequences + ([1024, 1024], [1024, 1024]), + ([256, 512, 256], [128, 256, 128]), +] + +SEQLEN_CONFIGS_SMOKE = [ + ([128, 128], [128, 128]), + ([64, 128], [64, 128]), + ([113, 203], [113, 203]), + ([256, 512], [256, 512]), + ([64, 128], [32, 64]), +] + + +# ============================================================================= +# Helper functions +# ============================================================================= + + +def setup_varlen_tensors( + seqlens_q, seqlens_k, num_heads, num_kv_heads, head_dim, dtype +): + """Create packed Q, K, V tensors and cu_seqlens for varlen.""" + device = "cuda" + batch_size = len(seqlens_q) + total_q = sum(seqlens_q) + total_k = sum(seqlens_k) + + q = torch.randn(total_q, num_heads, head_dim, device=device, dtype=dtype) + k = torch.randn(total_k, num_kv_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(total_k, num_kv_heads, head_dim, device=device, dtype=dtype) + + cu_seqlens_q = torch.tensor( + [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()), + device=device, + dtype=torch.int32, + ) + cu_seqlens_k = torch.tensor( + [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()), + device=device, + dtype=torch.int32, + ) + + return q, k, v, cu_seqlens_q, cu_seqlens_k + + +def run_flex_per_sequence( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + mask_mod_flex_factory, + seqlens_q, + seqlens_k, + num_heads, + num_kv_heads, + head_dim, + dtype=None, +): + """Run flex_attention per-sequence as reference for varlen. + + mask_mod_flex_factory(seq_idx, seqlen_q_i, seqlen_k_i) -> mask_mod function + that takes (b, h, q_idx, kv_idx) for that sequence. + """ + batch_size = len(seqlens_q) + results = [] + + for i in range(batch_size): + sq = seqlens_q[i] + sk = seqlens_k[i] + + # Extract packed slices + q_slice = q[cu_seqlens_q[i] : cu_seqlens_q[i + 1]].unsqueeze(0) # (1, sq, H, D) + k_slice = k[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].unsqueeze( + 0 + ) # (1, sk, Hkv, D) + v_slice = v[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].unsqueeze(0) + + if dtype is not None: + q_slice = q_slice.to(dtype) + k_slice = k_slice.to(dtype) + v_slice = v_slice.to(dtype) + + # Transpose to (B, H, S, D) for flex_attention + q_t = q_slice.transpose(1, 2) + k_t = k_slice.transpose(1, 2) + v_t = v_slice.transpose(1, 2) + + # Expand KV heads for GQA + if num_heads != num_kv_heads: + repeat_factor = num_heads // num_kv_heads + k_t = k_t.repeat_interleave(repeat_factor, dim=1) + v_t = v_t.repeat_interleave(repeat_factor, dim=1) + + scale = 1.0 / math.sqrt(head_dim) + + mask_mod = mask_mod_flex_factory(i, sq, sk) + + if mask_mod is None: + out = F.scaled_dot_product_attention(q_t, k_t, v_t, scale=scale) + else: + block_mask = create_block_mask( + mask_mod, + B=1, + H=num_heads, + Q_LEN=sq, + KV_LEN=sk, + device=q.device, + ) + out = flex_attention( + q_t, k_t, v_t, block_mask=block_mask, scale=scale, enable_gqa=True + ) + + results.append(out.transpose(1, 2).squeeze(0)) # back to (sq, H, D) + + return torch.cat(results, dim=0) + + +def check_varlen_results( + out_cute, + out_ref_fp32, + out_pt, + seqlens_q, + cu_seqlens_q, + test_name, + rtol=2, + extra_atol=2e-3, +): + """Compare CuTE output against per-sequence flex references.""" + assert not torch.isnan(out_cute).any(), f"{test_name}: NaN in output" + assert torch.isfinite(out_cute).all(), f"{test_name}: Inf in output" + assert out_cute.shape == out_ref_fp32.shape, ( + f"{test_name}: Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}" + ) + + num_seqs = len(seqlens_q) + max_cute_error = 0.0 + max_pt_error = 0.0 + + for i in range(num_seqs): + start = cu_seqlens_q[i] + end = cu_seqlens_q[i + 1] + cute_seq = out_cute[start:end] + ref_seq = out_ref_fp32[start:end] + pt_seq = out_pt[start:end] + + max_cute_error = max(max_cute_error, (cute_seq - ref_seq).abs().max().item()) + max_pt_error = max(max_pt_error, (pt_seq - ref_seq).abs().max().item()) + + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + + print(f"\n{test_name}:") + print(f" PyTorch vs FP32 ref: {max_pt_error:.2e}") + print(f" CuTE vs FP32 ref: {max_cute_error:.2e}") + + tol = rtol * max_pt_error + fwd_atol + extra_atol + assert max_cute_error <= tol, ( + f"{test_name}: CuTE error {max_cute_error:.2e} exceeds tolerance {tol:.2e} " + f"(rtol={rtol} * pt_err={max_pt_error:.2e} + fwd_atol={fwd_atol:.2e} + extra={extra_atol:.2e})" + ) + + +# ============================================================================= +# Core test runner +# ============================================================================= + + +def _run_varlen_mask_test( + seqlens_q, + seqlens_k, + num_heads, + num_kv_heads, + head_dim, + dtype, + mask_name, + window_size=None, +): + """Run a varlen mask_mod test: kernel with cu_seqlens vs per-sequence flex_attention.""" + torch.manual_seed(42) + random.seed(42) + + batch_size = len(seqlens_q) + pack_gqa = num_heads != num_kv_heads + + if mask_name == "sliding_window": + # Skip configs where any seqlen_q > seqlen_k + for sq, sk in zip(seqlens_q, seqlens_k): + if sq > sk: + pytest.skip( + "sliding_window requires seqlen_q <= seqlen_k for each sequence" + ) + + q, k, v, cu_seqlens_q, cu_seqlens_k = setup_varlen_tensors( + seqlens_q, seqlens_k, num_heads, num_kv_heads, head_dim, dtype + ) + + if mask_name == "block_causal": + offsets = [sk - sq for sq, sk in zip(seqlens_q, seqlens_k)] + if len(set(offsets)) > 1: + pytest.skip( + "block_causal captures offset as compile-time constant; " + "varlen with different per-sequence offsets not supported" + ) + + aux_tensors_arg = None + + if mask_name == "document": + max_seqlen = max(max(seqlens_q), max(seqlens_k)) + max_doc_len = max(max(seqlens_q), max(seqlens_k)) + doc_ids = random_doc_id_tensor( + num_heads, batch_size, max_doc_len, device="cuda" + ).to(dtype=torch.int32, device="cuda") + doc_ids.__leading_dim__ = 2 + aux_tensors_arg = [doc_ids] + + from mask_mod_definitions import flex_document_mask + + cute_mask_mod = get_mask_pair("document")[0] + + def flex_factory(seq_idx, sq, sk, doc_ids=doc_ids): + # Pre-slice to 1D using Python ints *outside* the vmapped closure. + # create_block_mask vmaps over all four args (b, h, q_idx, kv_idx); + # multi-dim indexing like doc_id[b, h, q_idx] with 0-dim vmap tensors + # triggers .item() internally. 1D tensor[0d_tensor] is a safe gather. + doc_row = doc_ids[seq_idx, 0] # (max_doc_len,) + + def _mask(b, h, q_idx, kv_idx): + return doc_row[q_idx] == doc_row[kv_idx] + + return _mask + + elif mask_name == "ima": + total_k = sum(seqlens_k) + pytest.skip( + "IMA mask requires global index handling for varlen - not yet implemented" + ) + + else: + if mask_name in STATIC_MASKS: + cute_mask_mod = get_mask_pair(mask_name)[0] + + def flex_factory(seq_idx, sq, sk): + return get_mask_pair(mask_name)[1] + + elif mask_name in PARAMETERIZED_MASK_FACTORIES: + cute_mask_mod = get_mask_pair( + mask_name, + seqlen_q=seqlens_q[0], + seqlen_k=seqlens_k[0], + window_size=window_size, + )[0] + + def flex_factory(seq_idx, sq, sk): + _, flex_mask = get_mask_pair( + mask_name, + seqlen_q=sq, + seqlen_k=sk, + window_size=window_size, + ) + return flex_mask + + else: + raise ValueError(f"Unknown mask: {mask_name}") + + # Run the kernel with varlen (packed format) + out = torch.empty_like(q) + softmax_scale = 1.0 / math.sqrt(head_dim) + + out_tuple = _flash_attn_fwd( + q=q, + k=k, + v=v, + out=out, + lse=None, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=softmax_scale, + causal=False, + softcap=None, + window_size_left=-1, + window_size_right=-1, + learnable_sink=None, + tile_mn=(128, 128), + pack_gqa=pack_gqa, + _arch=None, + score_mod=None, + mask_mod=cute_mask_mod, + block_sparse_tensors=None, + return_lse=True, + aux_tensors=aux_tensors_arg, + ) + out_cute = out_tuple[0] + + # Run per-sequence flex_attention references + out_ref_fp32 = run_flex_per_sequence( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + flex_factory, + seqlens_q, + seqlens_k, + num_heads, + num_kv_heads, + head_dim, + dtype=torch.float32, + ) + out_pt = run_flex_per_sequence( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + flex_factory, + seqlens_q, + seqlens_k, + num_heads, + num_kv_heads, + head_dim, + dtype=dtype, + ) + + # Check results + mask_desc = f"mask_mod={mask_name}" + if window_size is not None: + mask_desc += f"(w={window_size})" + test_name = ( + f"{mask_desc} varlen seqs_q={seqlens_q}, seqs_k={seqlens_k}, " + f"H={num_heads}/{num_kv_heads}, D={head_dim}" + ) + check_varlen_results( + out_cute, out_ref_fp32, out_pt, seqlens_q, cu_seqlens_q, test_name + ) + + +# ============================================================================= +# Test cases +# ============================================================================= + +# Masks that don't need recompilation per seqlen (fast) +STATIC_MASK_NAMES = ["block_diagonal", "mini_causal"] + +# Masks that need per-seqlen compilation (slower) +PARAMETERIZED_MASK_CONFIGS = [ + ("causal", None), + ("block_causal", None), + ("sliding_window", 128), + ("sliding_window", 256), + ("document", None), +] + + +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("kv_mode", ["mha", "gqa"]) +@pytest.mark.parametrize("mask_name", STATIC_MASK_NAMES) +def test_varlen_static_masks(seqlens_q, seqlens_k, dtype, kv_mode, mask_name): + """Test static mask_mods with varlen (packed) attention.""" + num_heads = 8 + if kv_mode == "gqa": + if COMPUTE_CAPABILITY < 9: + pytest.xfail("pack_gqa requires SM90+") + num_kv_heads = 2 + else: + num_kv_heads = num_heads + + _run_varlen_mask_test( + seqlens_q=seqlens_q, + seqlens_k=seqlens_k, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=128, + dtype=dtype, + mask_name=mask_name, + ) + + +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("kv_mode", ["mha", "gqa"]) +@pytest.mark.parametrize("mask_name,window_size", PARAMETERIZED_MASK_CONFIGS) +def test_varlen_parameterized_masks( + seqlens_q, seqlens_k, dtype, kv_mode, mask_name, window_size +): + """Test parameterized mask_mods with varlen (packed) attention. + + Uses fewer seqlen configs since these require recompilation per seqlen. + """ + num_heads = 8 + if kv_mode == "gqa": + if COMPUTE_CAPABILITY < 9: + pytest.xfail("pack_gqa requires SM90+") + num_kv_heads = 2 + else: + num_kv_heads = num_heads + + _run_varlen_mask_test( + seqlens_q=seqlens_q, + seqlens_k=seqlens_k, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=128, + dtype=dtype, + mask_name=mask_name, + window_size=window_size, + ) + + +# ============================================================================= +# Global-index mask test runner +# ============================================================================= + + +def _run_varlen_global_mask_test( + seqlens_q, + seqlens_k, + num_heads, + num_kv_heads, + head_dim, + dtype, + mask_name, +): + """Run a varlen global-index mask_mod test: kernel vs per-sequence flex_attention.""" + torch.manual_seed(42) + random.seed(42) + + pack_gqa = num_heads != num_kv_heads + + q, k, v, cu_seqlens_q, cu_seqlens_k = setup_varlen_tensors( + seqlens_q, seqlens_k, num_heads, num_kv_heads, head_dim, dtype + ) + + if mask_name == "global_packed_doc": + doc_ids_q, doc_ids_k = make_packed_doc_ids(seqlens_q, seqlens_k, device="cuda") + cute_mask = cute_global_packed_doc_mask + flex_fac = global_packed_doc_flex_factory( + doc_ids_q, doc_ids_k, cu_seqlens_q, cu_seqlens_k + ) + aux_tensors = [doc_ids_q, doc_ids_k] + elif mask_name == "global_ima": + thresholds = make_global_thresholds(seqlens_k, device="cuda") + cute_mask = cute_global_ima_mask + flex_fac = global_ima_flex_factory(thresholds, cu_seqlens_k) + aux_tensors = [thresholds] + elif mask_name == "global_causal_window": + windows = make_global_windows(seqlens_q, device="cuda") + cute_mask = cute_global_causal_window_mask + flex_fac = global_causal_window_flex_factory(windows, cu_seqlens_q) + aux_tensors = [windows] + else: + raise ValueError(f"Unknown global mask: {mask_name}") + + out = torch.empty_like(q) + softmax_scale = 1.0 / math.sqrt(head_dim) + + out_tuple = _flash_attn_fwd( + q=q, + k=k, + v=v, + out=out, + lse=None, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=softmax_scale, + causal=False, + softcap=None, + window_size_left=-1, + window_size_right=-1, + learnable_sink=None, + tile_mn=(128, 128), + pack_gqa=pack_gqa, + _arch=None, + score_mod=None, + mask_mod=cute_mask, + block_sparse_tensors=None, + return_lse=True, + aux_tensors=aux_tensors, + ) + out_cute = out_tuple[0] + + out_ref_fp32 = run_flex_per_sequence( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + flex_fac, + seqlens_q, + seqlens_k, + num_heads, + num_kv_heads, + head_dim, + dtype=torch.float32, + ) + out_pt = run_flex_per_sequence( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + flex_fac, + seqlens_q, + seqlens_k, + num_heads, + num_kv_heads, + head_dim, + dtype=dtype, + ) + + test_name = ( + f"global_mask={mask_name} varlen seqs_q={seqlens_q}, seqs_k={seqlens_k}, " + f"H={num_heads}/{num_kv_heads}, D={head_dim}" + ) + check_varlen_results( + out_cute, out_ref_fp32, out_pt, seqlens_q, cu_seqlens_q, test_name + ) + + +GLOBAL_MASK_NAMES = ["global_packed_doc", "global_ima", "global_causal_window"] + + +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("mask_name", GLOBAL_MASK_NAMES) +def test_varlen_global_masks(seqlens_q, seqlens_k, mask_name): + """Test global-index mask_mods (aux-tensor-driven) with varlen packed attention.""" + _run_varlen_global_mask_test( + seqlens_q, seqlens_k, 8, 8, 128, torch.bfloat16, mask_name + ) + + +# ============================================================================= +# Vectorized mask_mod equality tests (varlen) +# Pattern: scalar mask is reference; vec mask at multiple __vec_size__ values +# must produce bit-identical output. +# ============================================================================= + +# (mask_name, window_size, needs_aux) +VEC_MASK_TEST_CASES = [ + ("causal", None, False), + ("block_causal", None, False), + ("sliding_window", 128, False), + ("block_diagonal", None, False), + ("prefix_lm", None, False), + ("packed_aux", None, True), +] + +# Vectorized mask_mod application is currently implemented for SM100/SM110 forward. +# vec_size > 32 is only supported by packed_aux (other vec mods return shape-(1,) Uint32). +VEC_MASK_SIZES_TO_CHECK_EQUALITY = [2, 8, 32, 128] + + +def _run_varlen_mask_only( + q, k, v, cu_seqlens_q, cu_seqlens_k, mask_mod, aux_tensors, pack_gqa +): + out = torch.empty_like(q) + _flash_attn_fwd( + q=q, + k=k, + v=v, + out=out, + lse=None, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=1.0 / math.sqrt(q.shape[-1]), + causal=False, + softcap=None, + window_size_left=-1, + window_size_right=-1, + learnable_sink=None, + tile_mn=(128, 128), + pack_gqa=pack_gqa, + _arch=None, + score_mod=None, + mask_mod=mask_mod, + block_sparse_tensors=None, + return_lse=False, + aux_tensors=aux_tensors, + ) + return out + + +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS_SMOKE) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("kv_mode", ["mha", "gqa"]) +@pytest.mark.parametrize("mask_case", VEC_MASK_TEST_CASES) +def test_varlen_mask_mod_vectorized(seqlens_q, seqlens_k, dtype, kv_mode, mask_case): + """Tests equality between scalar and vectorized mask mods on varlen inputs.""" + if COMPUTE_CAPABILITY not in (10, 11): + pytest.skip("vectorized mask_mod application is SM100/SM110-only") + mask_name, window_size, needs_aux = mask_case + + if mask_name == "block_causal": + offsets = [sk - sq for sq, sk in zip(seqlens_q, seqlens_k)] + if len(set(offsets)) > 1: + pytest.skip( + "block_causal captures offset as compile-time constant; " + "varlen with different per-sequence offsets not supported" + ) + if mask_name == "sliding_window": + for sq, sk in zip(seqlens_q, seqlens_k): + if sq > sk: + pytest.skip( + "sliding_window requires seqlen_q <= seqlen_k for each sequence" + ) + + torch.manual_seed(42) + num_heads = 8 + if kv_mode == "gqa": + if COMPUTE_CAPABILITY < 9: + pytest.xfail("pack_gqa requires SM90+") + num_kv_heads = 2 + else: + num_kv_heads = num_heads + pack_gqa = num_heads != num_kv_heads + head_dim = 128 + + q, k, v, cu_seqlens_q, cu_seqlens_k = setup_varlen_tensors( + seqlens_q, seqlens_k, num_heads, num_kv_heads, head_dim, dtype + ) + + batch_size = len(seqlens_q) + max_seqlen_q = max(seqlens_q) + max_seqlen_k = max(seqlens_k) + + if needs_aux: + aux_tensors = [ + make_packed_mask_aux_tensor(batch_size, max_seqlen_q, max_seqlen_k) + ] + scalar_mod = EXTRA_SCALAR_MASKS[mask_name] + else: + aux_tensors = None + scalar_mod, _ = get_mask_pair( + mask_name, + seqlen_q=max_seqlen_q, + seqlen_k=max_seqlen_k, + window_size=window_size, + ) + + out_ref = _run_varlen_mask_only( + q, k, v, cu_seqlens_q, cu_seqlens_k, scalar_mod, aux_tensors, pack_gqa + ) + + for vec_size in VEC_MASK_SIZES_TO_CHECK_EQUALITY: + if vec_size > 32 and mask_name != "packed_aux": + continue + vec_mod = get_vec_mask( + mask_name, + seqlen_q=max_seqlen_q, + seqlen_k=max_seqlen_k, + window_size=window_size, + vec_size=vec_size, + ) + if vec_mod is None: + pytest.skip(f"no vec mask for {mask_name}") + vec_mod.__vec_size__ = vec_size + out = _run_varlen_mask_only( + q, k, v, cu_seqlens_q, cu_seqlens_k, vec_mod, aux_tensors, pack_gqa + ) + assert torch.equal(out, out_ref), ( + f"{mask_name} vec_size={vec_size}: output mismatch vs scalar reference" + ) + + +# ============================================================================= +# Block sparsity end-to-end tests +# ============================================================================= + + +@cute.jit +def cute_all_true_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + return m_idx >= utils.scalar_to_ssa(0, cutlass.Int32) + + +def _make_block_sparse_tensors( + mask_mod, + seqlens_q, + seqlens_k, + num_heads, + tile_m, + tile_n, + device, + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_k=None, + aux_tensors=None, +): + """Compute block sparse tensors. cu_total_m_blocks / cu_block_idx_offsets are + populated on the returned BlockSparseTensorsTorch.""" + batch_size = len(seqlens_q) + max_seqlen_q = max(seqlens_q) + max_seqlen_k = max(seqlens_k) + + if cu_seqlens_q is not None: + cu_total_m_blocks_list = [0] + for batch_idx in range(batch_size): + num_m_blocks = (seqlens_q[batch_idx] + tile_m - 1) // tile_m + cu_total_m_blocks_list.append(cu_total_m_blocks_list[-1] + num_m_blocks) + cu_total_m_blocks = torch.tensor( + cu_total_m_blocks_list, dtype=torch.int32, device=device + ) + + cu_block_idx_offsets = None + if cu_seqlens_k is not None or seqused_k is not None: + cu_block_idx_offsets_list = [0] + for batch_idx in range(batch_size): + num_m_blocks = (seqlens_q[batch_idx] + tile_m - 1) // tile_m + num_n_blocks = (seqlens_k[batch_idx] + tile_n - 1) // tile_n + cu_block_idx_offsets_list.append( + cu_block_idx_offsets_list[-1] + num_m_blocks * num_n_blocks + ) + cu_block_idx_offsets = torch.tensor( + cu_block_idx_offsets_list, dtype=torch.int32, device=device + ) + else: + cu_total_m_blocks = None + cu_block_idx_offsets = None + + block_sparse_tensors = compute_block_sparsity( + tile_m=tile_m, + tile_n=tile_n, + batch_size=batch_size, + num_heads=num_heads, + seqlen_q=max_seqlen_q, + seqlen_k=max_seqlen_k, + mask_mod=mask_mod, + aux_tensors=aux_tensors, + device=device, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + cu_total_m_blocks=cu_total_m_blocks, + cu_block_idx_offsets=cu_block_idx_offsets, + seqused_k=seqused_k, + ) + return block_sparse_tensors + + +def _run_fwd( + q, + k, + v, + mask_mod, + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_k=None, + block_sparse_tensors=None, + aux_tensors=None, + num_splits=1, +): + out = torch.empty_like(q) + return _flash_attn_fwd( + q=q, + k=k, + v=v, + out=out, + lse=None, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=None, + seqused_k=seqused_k, + page_table=None, + softmax_scale=1.0 / math.sqrt(q.shape[-1]), + causal=False, + softcap=None, + window_size_left=-1, + window_size_right=-1, + learnable_sink=None, + tile_mn=(128, 128), + pack_gqa=False, + _arch=None, + score_mod=None, + mask_mod=mask_mod, + block_sparse_tensors=block_sparse_tensors, + return_lse=False, + aux_tensors=aux_tensors, + num_splits=num_splits, + )[0] + + +BLOCK_SPARSE_MASK_NAMES = [ + "causal", + "block_diagonal", + "mini_causal", + "prefix_lm", + "sliding_window", + "dilated_sliding_window", + "document", + "ima", +] + +BLOCK_SPARSE_SEQLEN_CONFIGS = [ + ([128, 192, 256], [128, 192, 256]), + ([64, 128], [64, 128]), + ([256, 512, 256], [256, 512, 256]), + ([128, 192, 256], [64, 128, 192]), +] + + +# @pytest.mark.parametrize("varlen_q", [False, True]) +@pytest.mark.parametrize("seqlens", BLOCK_SPARSE_SEQLEN_CONFIGS) +@pytest.mark.parametrize("mask_name", BLOCK_SPARSE_MASK_NAMES) +@pytest.mark.parametrize("varlen_q", [False, True]) +@pytest.mark.parametrize("varlen_k", [False, True]) +@pytest.mark.parametrize("use_seqused_k", [False, True]) +@pytest.mark.parametrize("head_broadcast", [False, True]) +def test_varlen_block_sparse( + varlen_q, varlen_k, use_seqused_k, head_broadcast, mask_name, seqlens +): + """Block sparsity + mask_mod should produce identical output to mask_mod alone.""" + if varlen_k and use_seqused_k: + pytest.skip("packed K (cu_seqlens_k) and seqused_k are mutually exclusive") + if not varlen_q and varlen_k: + pytest.skip( + "block sparsity with padded Q + packed K requires per-batch n-block offsets; not yet supported" + ) + + torch.manual_seed(42) + random.seed(42) + device = "cuda" + num_heads = 4 + head_dim = 128 + dtype = torch.bfloat16 + # On Blackwell (SM100) q_stage=2 → effective tile is 256; elsewhere 128. + tile_m = 256 if COMPUTE_CAPABILITY >= 10 else 128 + tile_n = 128 + + base_seqlens_q, base_seqlens_k = seqlens + batch_size = len(base_seqlens_q) + max_seqlen_q = max(base_seqlens_q) + max_seqlen_k = max(base_seqlens_k) + + seqlens_q = base_seqlens_q if varlen_q else [max_seqlen_q] * batch_size + seqlens_k = ( + base_seqlens_k if (varlen_k or use_seqused_k) else [max_seqlen_k] * batch_size + ) + + def make_cu_seqlens(seqlens): + return torch.tensor( + [0] + list(torch.tensor(seqlens).cumsum(0).tolist()), + device=device, + dtype=torch.int32, + ) + + if varlen_q: + q = torch.randn(sum(seqlens_q), num_heads, head_dim, device=device, dtype=dtype) + cu_seqlens_q = make_cu_seqlens(seqlens_q) + else: + q = torch.randn( + batch_size, max_seqlen_q, num_heads, head_dim, device=device, dtype=dtype + ) + cu_seqlens_q = None + + if varlen_k: + k = torch.randn(sum(seqlens_k), num_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(sum(seqlens_k), num_heads, head_dim, device=device, dtype=dtype) + cu_seqlens_k = make_cu_seqlens(seqlens_k) + seqused_k = None + else: + k = torch.randn( + batch_size, max_seqlen_k, num_heads, head_dim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, max_seqlen_k, num_heads, head_dim, device=device, dtype=dtype + ) + cu_seqlens_k = None + seqused_k = ( + torch.tensor(seqlens_k, dtype=torch.int32, device=device) + if use_seqused_k + else None + ) + + # Build mask_mod and aux_tensors for the requested mask + aux_tensors = None + if mask_name == "causal": + mask_mod, _ = get_mask_pair( + "causal", seqlen_q=max_seqlen_q, seqlen_k=max_seqlen_k + ) + elif mask_name == "sliding_window": + mask_mod, _ = get_mask_pair( + "sliding_window", + seqlen_q=max_seqlen_q, + seqlen_k=max_seqlen_k, + window_size=128, + ) + elif mask_name == "document": + max_doc_len = max(max_seqlen_q, max_seqlen_k) + doc_ids = random_doc_id_tensor( + num_heads, batch_size, max_doc_len, device=device + ) + doc_ids.__leading_dim__ = 2 + aux_tensors = [doc_ids] + mask_mod = get_mask_pair("document")[0] + elif mask_name == "ima": + bias = torch.randint( + 0, + max(1, max_seqlen_k // 2), + (max_seqlen_k,), + dtype=torch.int32, + device=device, + ) + aux_tensors = [bias] + mask_mod = get_mask_pair("ima")[0] + else: + mask_mod = get_mask_pair(mask_name)[0] + + num_heads_sparse = 1 if head_broadcast else num_heads + block_sparse_tensors = _make_block_sparse_tensors( + mask_mod=mask_mod, + seqlens_q=seqlens_q, + seqlens_k=seqlens_k, + num_heads=num_heads_sparse, + tile_m=tile_m, + tile_n=tile_n, + device=device, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + aux_tensors=aux_tensors, + ) + + out_with_block_sparsity = _run_fwd( + q, + k, + v, + mask_mod, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + block_sparse_tensors=block_sparse_tensors, + aux_tensors=aux_tensors, + ) + out_no_block_sparsity = _run_fwd( + q, + k, + v, + mask_mod, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + aux_tensors=aux_tensors, + ) + + assert not torch.isnan(out_with_block_sparsity).any(), "NaN in block-sparse output" + max_err = (out_with_block_sparsity - out_no_block_sparsity).abs().max().item() + assert max_err <= 0.01, ( + f"block-sparse output differs from mask-mod-only by {max_err}" + ) + + +VARLEN_BLOCK_SPARSE_SPLITKV_SEQLENS = [ + ([128], [2048]), + ([96], [1536]), + ([128, 64], [2048, 1024]), + ([1, 128], [256, 2048]), +] + +VARLEN_BLOCK_SPARSE_SPLITKV_MASKS = [ + "all_true", + "causal", + "sliding_window", + "prefix_lm", + "block_diagonal", +] + + +def _get_splitkv_varlen_mask(mask_name, max_seqlen_q, max_seqlen_k): + match mask_name: + case "all_true": + return cute_all_true_mask + case "causal": + return get_mask_pair("causal", seqlen_q=max_seqlen_q, seqlen_k=max_seqlen_k)[0] + case "sliding_window": + return get_mask_pair( + "sliding_window", + seqlen_q=max_seqlen_q, + seqlen_k=max_seqlen_k, + window_size=512, + )[0] + case _: + return get_mask_pair(mask_name)[0] + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="SM100/SM110 SplitKV forward only") +@pytest.mark.parametrize("use_seqused_k", [False, True]) +@pytest.mark.parametrize("mask_name", VARLEN_BLOCK_SPARSE_SPLITKV_MASKS) +@pytest.mark.parametrize("seqlens_q,seqlens_k", VARLEN_BLOCK_SPARSE_SPLITKV_SEQLENS) +def test_varlen_block_sparse_splitkv_matches_unsplit(seqlens_q, seqlens_k, mask_name, use_seqused_k): + """Varlen block-sparse SplitKV should match the unsplit block-sparse path.""" + torch.manual_seed(123) + random.seed(123) + device = "cuda" + num_heads = 4 + head_dim = 128 + dtype = torch.bfloat16 + sparse_tile_m = 256 if sum(seqlens_q) > 128 else 128 + tile_n = 128 + max_seqlen_q = max(seqlens_q) + max_seqlen_k = max(seqlens_k) + + q = torch.randn(sum(seqlens_q), num_heads, head_dim, device=device, dtype=dtype) + cu_seqlens_q = torch.tensor( + [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()), + device=device, + dtype=torch.int32, + ) + mask_mod = _get_splitkv_varlen_mask(mask_name, max_seqlen_q, max_seqlen_k) + + if use_seqused_k: + k = torch.randn( + len(seqlens_k), max_seqlen_k, num_heads, head_dim, device=device, dtype=dtype + ) + v = torch.randn_like(k) + cu_seqlens_k = None + seqused_k = torch.tensor(seqlens_k, dtype=torch.int32, device=device) + else: + k = torch.randn(sum(seqlens_k), num_heads, head_dim, device=device, dtype=dtype) + v = torch.randn_like(k) + cu_seqlens_k = torch.tensor( + [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()), + device=device, + dtype=torch.int32, + ) + seqused_k = None + + block_sparse_tensors = _make_block_sparse_tensors( + mask_mod=mask_mod, + seqlens_q=seqlens_q, + seqlens_k=seqlens_k, + num_heads=1, + tile_m=sparse_tile_m, + tile_n=tile_n, + device=device, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + ) + + out_unsplit = _run_fwd( + q, + k, + v, + mask_mod, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + block_sparse_tensors=block_sparse_tensors, + ) + out_split = _run_fwd( + q, + k, + v, + mask_mod, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + block_sparse_tensors=block_sparse_tensors, + num_splits=3, + ) + out_no_block_sparsity = _run_fwd( + q, + k, + v, + mask_mod, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + ) + + assert not torch.isnan(out_split).any(), "NaN in SplitKV block-sparse output" + assert torch.isfinite(out_split).all(), "Inf in SplitKV block-sparse output" + assert (out_unsplit - out_no_block_sparsity).abs().max().item() <= 0.01 + assert (out_split - out_no_block_sparsity).abs().max().item() <= 0.01 + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/cute/test_utils.py b/tests/cute/test_utils.py index 189eb86957d..4ec077933cc 100644 --- a/tests/cute/test_utils.py +++ b/tests/cute/test_utils.py @@ -160,6 +160,19 @@ def tracking_sha256(*args, **kwargs): assert call_tracker["sha256"] == 0, "sha256 should not be called" assert result == "wrapped-fast-hash" + def test_vec_size_affects_hash(self): + def mask_mod(_b, _h, q_idx, kv_idx): + return q_idx >= kv_idx + + base_hash = hash_callable(mask_mod, set_cute_hash=False) + mask_mod.__vec_size__ = 16 + vec16_hash = hash_callable(mask_mod, set_cute_hash=False) + mask_mod.__vec_size__ = 32 + vec32_hash = hash_callable(mask_mod, set_cute_hash=False) + + assert base_hash != vec16_hash + assert vec16_hash != vec32_hash + def test_closure_values_affect_hash(self): """Functions with different closure values should have different hashes.""" value1 = 10 diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 77e4582ee37..c38dd02ddff 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -2545,7 +2545,7 @@ def test_flash_attn_paged_kvcache_overflow( paged_kv_block_size, causal, dtype, -): +): device = "cuda" num_blocks = 1000*16//paged_kv_block_size key_cache = torch.rand([num_blocks, paged_kv_block_size, nheads, d], dtype=dtype, device=device) @@ -2567,3 +2567,61 @@ def test_flash_attn_paged_kvcache_overflow( block_table=block_tables, causal=causal, ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_flash_attn_varlen_paged_kv_num_splits(dtype): + """Passing num_splits=0 explicitly should be bitwise identical to not passing it (default).""" + from flash_attn.flash_attn_interface import _flash_attn_varlen_forward + + device = "cuda" + num_heads, num_heads_k, head_dim = 4, 2, 64 + page_block_size = 256 + scale = head_dim ** -0.5 + + batch_size = 2 + kv_lens = [512, 1024] + max_seqlen_k = max(kv_lens) + + max_blocks_per_seq = max( + (s + page_block_size - 1) // page_block_size for s in kv_lens + ) + total_blocks = batch_size * max_blocks_per_seq + k_cache = torch.randn( + total_blocks, page_block_size, num_heads_k, head_dim, + device=device, dtype=dtype, + ) + v_cache = torch.randn( + total_blocks, page_block_size, num_heads_k, head_dim, + device=device, dtype=dtype, + ) + + block_table = rearrange( + torch.randperm(total_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + + q = torch.randn(batch_size, num_heads, head_dim, device=device, dtype=dtype) + cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=device) + seqused_k = torch.tensor(kv_lens, dtype=torch.int32, device=device) + cu_seqlens_k = torch.nn.functional.pad(seqused_k.cumsum(0), (1, 0)).to(torch.int32) + + fwd_kwargs = dict( + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=1, max_seqlen_k=max_seqlen_k, + dropout_p=0.0, softmax_scale=scale, + causal=False, window_size_left=-1, window_size_right=0, + block_table=block_table, seqused_k=seqused_k, + ) + + out_default = _flash_attn_varlen_forward(q, k_cache, v_cache, **fwd_kwargs)[0] + out_explicit = _flash_attn_varlen_forward(q, k_cache, v_cache, **fwd_kwargs, num_splits=0)[0] + + assert not out_default.isnan().any(), "default num_splits produced NaN" + assert torch.equal(out_default, out_explicit), ( + f"default vs num_splits=0 differ: max diff {(out_default - out_explicit).abs().max().item()}" + ) + + with pytest.raises(RuntimeError, match="num_splits > 1 is not supported"): + _flash_attn_varlen_forward(q, k_cache, v_cache, **fwd_kwargs, num_splits=2) diff --git a/third_party/aiter b/third_party/aiter index b4b75165fbd..3b2e6f48ce9 160000 --- a/third_party/aiter +++ b/third_party/aiter @@ -1 +1 @@ -Subproject commit b4b75165fbd2456dfd0f074c5b2ef91bc87d97e5 +Subproject commit 3b2e6f48ce97e1d494e8b3f1af5c65f74e304b28