diff --git a/flash_attn/cute/benchmark_mask_mod.py b/flash_attn/cute/benchmark_mask_mod.py index b1aadd8939..9b7950ba07 100644 --- a/flash_attn/cute/benchmark_mask_mod.py +++ b/flash_attn/cute/benchmark_mask_mod.py @@ -21,7 +21,11 @@ create_cute_sliding_window_mask, create_flex_sliding_window_mask, ) -from block_sparsity import compute_block_sparsity +from flash_attn.cute.block_sparsity import ( + compute_block_sparsity, + BlockSparseTensorsTorch, + to_cute_block_sparse_tensors, +) @dataclass @@ -265,10 +269,12 @@ def _create_tensors(self) -> Dict[str, torch.Tensor]: ) if all(t is not None for t in [full_cnt, full_idx, mask_cnt, mask_idx]): - tensors["full_block_cnt"] = full_cnt.contiguous() - tensors["full_block_idx"] = full_idx.contiguous() - tensors["mask_block_cnt"] = mask_cnt.contiguous() - tensors["mask_block_idx"] = mask_idx.contiguous() + tensors["block_sparse_tensors"] = BlockSparseTensorsTorch( + mask_block_cnt=mask_cnt.contiguous(), + mask_block_idx=mask_idx.contiguous(), + full_block_cnt=full_cnt.contiguous(), + full_block_idx=full_idx.contiguous(), + ) if config.verbose: total_full = full_cnt.sum().item() @@ -373,33 +379,9 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple] else None ) - # Block sparsity tensors - full_block_cnt_cute = ( - from_dlpack(tensors["full_block_cnt"].detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=2 - ) - if "full_block_cnt" in tensors - else None - ) - full_block_idx_cute = ( - from_dlpack(tensors["full_block_idx"].detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=3 - ) - if "full_block_idx" in tensors - else None - ) - mask_block_cnt_cute = ( - from_dlpack(tensors["mask_block_cnt"].detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=2 - ) - if "mask_block_cnt" in tensors - else None - ) - mask_block_idx_cute = ( - from_dlpack(tensors["mask_block_idx"].detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=3 - ) - if "mask_block_idx" in tensors + blocksparse_tensors_cute = ( + to_cute_block_sparse_tensors(tensors["block_sparse_tensors"]) + if "block_sparse_tensors" in tensors else None ) @@ -436,11 +418,8 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple] None, # page_table window_left_cute, window_right_cute, - learnable_sink_cute, # learnable_sink - full_block_cnt_cute, - full_block_idx_cute, - mask_block_cnt_cute, - mask_block_idx_cute, + learnable_sink_cute, + blocksparse_tensors_cute, aux_tensors_cute, # None, ) @@ -461,10 +440,7 @@ def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple] window_left_cute, window_right_cute, learnable_sink_cute, - full_block_cnt_cute, - full_block_idx_cute, - mask_block_cnt_cute, - mask_block_idx_cute, + blocksparse_tensors_cute, aux_tensors_cute, # None, ) diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index be685dea5d..c28df4c20d 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -8,13 +8,92 @@ by a more robust preprocessing kernel in the future. """ -from typing import Tuple, Optional, Callable, List +from typing import Tuple, Optional, Callable, List, NamedTuple import torch +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack # placeholder Config = type("Config", (), {}) +class BlockSparseTensors(NamedTuple): + mask_block_cnt: cute.Tensor + mask_block_idx: cute.Tensor + full_block_cnt: Optional[cute.Tensor] + full_block_idx: Optional[cute.Tensor] + + def __new_from_mlir_values__(self, values): + return BlockSparseTensors(*values) + + +class BlockSparseTensorsTorch(NamedTuple): + mask_block_cnt: torch.Tensor + mask_block_idx: torch.Tensor + full_block_cnt: Optional[torch.Tensor] = None + full_block_idx: Optional[torch.Tensor] = None + + +def validate_block_sparse_tensors(tensors: BlockSparseTensorsTorch) -> None: + for name, cnt, idx in ( + ("mask", tensors.mask_block_cnt, tensors.mask_block_idx), + ("full", tensors.full_block_cnt, tensors.full_block_idx), + ): + if (cnt is None) != (idx is None): + raise ValueError( + f"{name}_block_cnt and {name}_block_idx must both be provided or both be None" + ) + if cnt is None: + continue + if cnt.dtype != torch.int32 or idx.dtype != torch.int32: + raise ValueError(f"{name}_block tensors must have dtype torch.int32") + if cnt.device != idx.device: + raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device") + if not cnt.is_cuda or not idx.is_cuda: + raise ValueError(f"{name}_block tensors must live on CUDA") + + if tensors.full_block_cnt is not None and tensors.mask_block_cnt is not None: + if tensors.full_block_cnt.device != tensors.mask_block_cnt.device: + raise ValueError("All block sparse tensors must be on the same device") + + +def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool: + return any(t is not None for t in (tensors.full_block_cnt, tensors.mask_block_cnt)) + + +def to_cute_block_sparse_tensors(tensors: BlockSparseTensorsTorch) -> Optional[BlockSparseTensors]: + if not is_block_sparsity_enabled(tensors): + return None + + mask_block_cnt_tensor = from_dlpack( + tensors.mask_block_cnt.detach(), assumed_align=4 + ).mark_layout_dynamic(leading_dim=2) + mask_block_idx_tensor = from_dlpack( + tensors.mask_block_idx.detach(), assumed_align=4 + ).mark_layout_dynamic(leading_dim=3) + full_block_cnt_tensor = ( + from_dlpack(tensors.full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + if tensors.full_block_cnt is not None + else None + ) + full_block_idx_tensor = ( + from_dlpack(tensors.full_block_idx.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=3 + ) + if tensors.full_block_idx is not None + else None + ) + + return BlockSparseTensors( + mask_block_cnt_tensor, + mask_block_idx_tensor, + full_block_cnt_tensor, + full_block_idx_tensor, + ) + + def compute_block_sparsity( config: Config, mask_mod_flex: Optional[Callable], diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index b49a693dfc..16d57991f9 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -29,6 +29,7 @@ from flash_attn.cute.softmax import Softmax, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors from flash_attn.cute import pipeline from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute.named_barrier import NamedBarrierFwd @@ -1271,10 +1272,7 @@ def __call__( window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, - full_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) - full_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) - mask_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) - mask_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) + blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors: Optional[list] = None, ): """Configures and launches the flash attention kernel. @@ -1290,6 +1288,7 @@ def __call__( ) ) + # Assume all strides are divisible by 128 bits except the last stride new_stride = lambda t: ( *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), @@ -1325,9 +1324,8 @@ def __call__( ) # self.num_mma_regs = 232 # self.num_producer_regs = 40 - self.use_block_sparsity = const_expr( - mask_block_cnt is not None and full_block_cnt is not None - ) + self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + self.use_scheduler_barrier = ( (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128) if const_expr(self.intra_wg_overlap) @@ -1521,10 +1519,7 @@ def __call__( window_size_left, window_size_right, learnable_sink, - full_block_cnt, - full_block_idx, - mask_block_cnt, - mask_block_idx, + blocksparse_tensors, self.sQ_layout, self.sK_layout, self.sV_layout, @@ -1571,10 +1566,7 @@ def kernel( window_size_left: Optional[Int32], window_size_right: Optional[Int32], learnable_sink: Optional[cute.Tensor], - full_block_cnt: Optional[cute.Tensor], - full_block_idx: Optional[cute.Tensor], - mask_block_cnt: Optional[cute.Tensor], - mask_block_idx: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, @@ -1698,10 +1690,7 @@ def kernel( pipeline_k, pipeline_v, mbar_ptr_Q, - full_block_cnt, - full_block_idx, - mask_block_cnt, - mask_block_idx, + blocksparse_tensors, block_info, SeqlenInfoCls, TileSchedulerCls, @@ -1740,10 +1729,7 @@ def kernel( SeqlenInfoCls, AttentionMaskCls, TileSchedulerCls, - full_block_cnt, - full_block_idx, - mask_block_cnt, - mask_block_idx, + blocksparse_tensors, aux_tensors, fastdiv_mods, ) @@ -1763,10 +1749,7 @@ def load( pipeline_k: cutlass.pipeline.PipelineAsync, pipeline_v: cutlass.pipeline.PipelineAsync, mbar_ptr_Q: cutlass.Pointer, - full_block_cnt: Optional[cute.Tensor], - full_block_idx: Optional[cute.Tensor], - mask_block_cnt: Optional[cute.Tensor], - mask_block_idx: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, @@ -1852,6 +1835,7 @@ def load( # ========================================== # Flex Attention blocksparsity # ========================================== + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] @@ -2033,10 +2017,7 @@ def mma( SeqlenInfoCls: Callable, AttentionMaskCls: Callable, TileSchedulerCls: Callable, - full_block_cnt: Optional[cute.Tensor], - full_block_idx: Optional[cute.Tensor], - mask_block_cnt: Optional[cute.Tensor], - mask_block_idx: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], aux_tensors: Optional[list], fastdiv_mods=None, ): @@ -2263,6 +2244,7 @@ def mma( # ========================================== # Block sparsity # ========================================== + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 0758d3f405..ec8582f6c1 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -33,6 +33,7 @@ from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils @@ -223,10 +224,7 @@ def __call__( window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, - full_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) - full_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) - mask_block_cnt: Optional[cute.Tensor] = None, # (b, h, m_block) - mask_block_idx: Optional[cute.Tensor] = None, # (b, h, m_block, n_block) + blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors: Optional[list] = None, ): """Execute the Fused Multi-Head Attention operation on the provided tensors. @@ -242,7 +240,6 @@ def __call__( 5. Grid and work scheduling computation 6. Kernel launch with appropriate parameters """ - # setup static attributes before smem/grid/tma computation self.q_dtype = mQ.element_type self.k_dtype = mK.element_type diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index b77a70d921..6db43f36cf 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -41,6 +41,8 @@ from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine +from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch, to_cute_block_sparse_tensors + def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x @@ -78,10 +80,7 @@ def _flash_attn_fwd( _compute_capability: Optional[int] = None, score_mod: Optional[Callable] = None, mask_mod: Optional[Callable] = None, - full_block_cnt: Optional[torch.Tensor] = None, - full_block_idx: Optional[torch.Tensor] = None, - mask_block_cnt: Optional[torch.Tensor] = None, - mask_block_idx: Optional[torch.Tensor] = None, + block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, @@ -155,10 +154,7 @@ def _flash_attn_fwd( if learnable_sink is not None: assert learnable_sink.shape == (num_head,) assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" - for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]: - if t is not None: - assert t.dtype == torch.int32, "blocksparse mask tensors must be int32" - # assert t.stride(0) == 1, "blocksparse mask tensors must be contiguous" + assert all( t is None or t.is_cuda for t in ( @@ -171,10 +167,6 @@ def _flash_attn_fwd( seqused_k, page_table, learnable_sink, - full_block_cnt, - full_block_idx, - mask_block_cnt, - mask_block_idx, ) ), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" @@ -258,28 +250,13 @@ def _flash_attn_fwd( if page_table is not None else None ) - - full_block_cnt_tensor = ( - from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) - if full_block_cnt is not None + sparse_tensors = ( + to_cute_block_sparse_tensors(block_sparse_tensors) + if block_sparse_tensors is not None else None ) - full_block_idx_tensor = ( - from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) - if full_block_idx is not None - else None - ) - mask_block_cnt_tensor = ( - from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) - if mask_block_cnt is not None - else None - ) - mask_block_idx_tensor = ( - from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) - if mask_block_idx is not None - else None - ) - use_block_sparsity = full_block_cnt is not None or mask_block_cnt is not None + + use_block_sparsity = sparse_tensors is not None if mask_mod is None: if causal: @@ -415,6 +392,8 @@ def _flash_attn_fwd( assert page_size in [None, 128], ( "Only page_size=128 is supported for paged KV on SM 10.0" ) + if sparse_tensors is not None: + raise NotImplementedError("BlockSparsity not yet supported on SM 10.0") fa_fwd = FlashAttentionForwardSm100( head_dim, head_dim_v, @@ -451,10 +430,7 @@ def _flash_attn_fwd( window_size_left, window_size_right, learnable_sink_tensor, - full_block_cnt_tensor, - full_block_idx_tensor, - mask_block_cnt_tensor, - mask_block_idx_tensor, + sparse_tensors, cute_aux_tensors, ) _flash_attn_fwd.compile_cache[compile_key]( @@ -473,10 +449,7 @@ def _flash_attn_fwd( window_size_left, window_size_right, learnable_sink_tensor, - full_block_cnt_tensor, - full_block_idx_tensor, - mask_block_cnt_tensor, - mask_block_idx_tensor, + sparse_tensors, cute_aux_tensors, ) return out, lse diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index ce3a28b82c..033d08f296 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from flash_attn.cute.interface import _flash_attn_fwd -from flash_attn.cute.block_sparsity import compute_block_sparsity +from flash_attn.cute.block_sparsity import compute_block_sparsity, BlockSparseTensorsTorch from flash_attn.cute.mask_definitions import ( MASK_FUNCTIONS, flex_causal_mask, @@ -304,6 +304,14 @@ class Config: # print(f" First Q block - full indices: {full_idx[0,0,0,:full_cnt[0,0,0].item()]}") # if mask_cnt[0,0,0] > 0: # print(f" First Q block - mask indices: {mask_idx[0,0,0,:mask_cnt[0,0,0].item()]}") + block_sparse_mask = None + if use_mask_mod: + block_sparse_mask = BlockSparseTensorsTorch( + mask_block_cnt=mask_cnt, + mask_block_idx=mask_idx, + full_block_cnt=full_cnt, + full_block_idx=full_idx, + ) out_tuple = _flash_attn_fwd( q=tensors["q"], @@ -329,10 +337,7 @@ class Config: _compute_capability=None, score_mod=None, mask_mod=mask_mod_cute, - full_block_cnt=full_cnt, - full_block_idx=full_idx, - mask_block_cnt=mask_cnt, - mask_block_idx=mask_idx, + block_sparse_tensors=block_sparse_mask, return_lse=True, aux_tensors=None, )