From 495ef79454d54f37eae6067904079d4ac03776b2 Mon Sep 17 00:00:00 2001 From: apophis Date: Thu, 7 May 2026 02:23:22 +0800 Subject: [PATCH 01/21] [ROCm Windows] fix build failed (#2519) * [ROCm Windows] fix triton requirement * pin triton-windows>=3.6.0 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a8f912f7862..50f4b2fc79e 100644 --- a/setup.py +++ b/setup.py @@ -673,7 +673,7 @@ def spawn(cmd): # Note: torch is excluded because pip resolves it to CUDA PyTorch from PyPI, overwriting any pre-installed ROCm PyTorch. Users must have torch installed. install_requires = [ "einops", - "triton==3.5.1", + "triton==3.5.1" if sys.platform != "win32" else "triton-windows>=3.6.0", ] else: install_requires = [ From c263382a091cc0545afd2169a873328243f792c8 Mon Sep 17 00:00:00 2001 From: Reuben Stern <107093092+reubenconducts@users.noreply.github.com> Date: Wed, 6 May 2026 19:19:21 -0400 Subject: [PATCH 02/21] don't disable 2cta due to cuda 12 in bwd (#2543) --- flash_attn/cute/interface.py | 2 +- flash_attn/cute/utils.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index d12ea7b80b4..9e9899d590b 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -479,7 +479,7 @@ def _flash_attn_fwd( ) requested_use_clc_scheduler = utils._get_use_clc_scheduler_default() - requested_disable_2cta = utils._get_disable_2cta_default() + requested_disable_2cta = utils._get_disable_2cta_default(is_fwd=True) current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index b9dc4f5c112..ffabd34e398 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -86,8 +86,11 @@ def _get_use_clc_scheduler_default() -> bool: return _fa_clc_enabled -def _get_disable_2cta_default() -> bool: - return _fa_disable_2cta_enabled or _fa_disable_2cta_cuda12 +def _get_disable_2cta_default(is_fwd: bool = False) -> bool: + if is_fwd: + return _fa_disable_2cta_enabled or _fa_disable_2cta_cuda12 + else: + return _fa_disable_2cta_enabled def _compute_base_hash(func: Callable) -> str: From 9192248e4a61b75629258d151ea45ad0e0161ff5 Mon Sep 17 00:00:00 2001 From: Reuben Stern <107093092+reubenconducts@users.noreply.github.com> Date: Wed, 6 May 2026 23:32:02 -0400 Subject: [PATCH 03/21] [CuTe,Bwd] guard softcap for varlen backward (#2544) --- flash_attn/cute/interface.py | 2 +- tests/cute/test_flash_attn.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 9e9899d590b..bcfc406b023 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -1419,7 +1419,7 @@ def _flash_attn_bwd( ) score_mod = utils.create_softcap_scoremod(softcap) score_mod_bwd = utils.create_softcap_scoremod_bwd(softcap) - elif score_mod is not None: + if score_mod is not None: assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided" assert cu_seqlens_q is None and cu_seqlens_k is None, ( "varlen + score_mod not supported in bwd yet" diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index f96412dd7c7..21ed3a48d57 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -842,6 +842,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): or (IS_SM100 and d == 256 and dv == 256) ) and not has_learnable_sink + and softcap == 0.0 # TODO: support softcap != 0.0 in varlen bwd # and False ): if d > 192 and IS_SM90: From 25b451eca0c55313c76e04b9c35c39943ff0d8c1 Mon Sep 17 00:00:00 2001 From: Reuben Stern <107093092+reubenconducts@users.noreply.github.com> Date: Wed, 6 May 2026 23:55:19 -0400 Subject: [PATCH 04/21] [CuTe,Flex] varlen blocksparsity (#2224) * varlen block-sparsity for forward Squashed forward-path varlen support: extends BlockSparseTensors usage to [num_heads, total_m_blocks] / [num_heads, total_n_blocks] layouts, threads cu_seqlens / cu_total_m_blocks / cu_total_n_blocks through the kernel and compute_block_sparsity, and routes through get_curr_blocksparse_tensors and get_total_block_count for shape-aware indexing. * rename cu_total_n_blocks to cu_block_idx_offsets; move cu_total_m_blocks/cu_block_idx_offsets into BlockSparseTensors instead of threading them as standalone parameters; drop the two [-1].item() syncs in normalize_block_sparse_config --- flash_attn/cute/block_sparse_utils.py | 196 +++-- flash_attn/cute/block_sparsity.py | 88 ++- flash_attn/cute/compute_block_sparsity.py | 472 ++++++++---- flash_attn/cute/flash_fwd_sm100.py | 15 + flash_attn/cute/flash_fwd_sm90.py | 6 + flash_attn/cute/interface.py | 48 +- flash_attn/cute/seqlen_info.py | 15 + tests/cute/benchmark_block_sparsity.py | 6 +- tests/cute/mask_mod_definitions.py | 231 +++++- tests/cute/test_block_sparsity.py | 379 ++++++++- tests/cute/test_mask_mod.py | 7 +- tests/cute/test_mask_mod_varlen.py | 895 ++++++++++++++++++++++ 12 files changed, 2090 insertions(+), 268 deletions(-) create mode 100644 tests/cute/test_mask_mod_varlen.py diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index d664b16dc64..def4f088d92 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -5,7 +5,7 @@ These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads. """ -from typing import Callable, Optional +from typing import Callable, Optional, Tuple from functools import partial import math import cutlass @@ -17,6 +17,67 @@ # Import data structures from block_sparsity from flash_attn.cute.block_sparsity import BlockSparseTensors from flash_attn.cute.named_barrier import NamedBarrierBwd +from flash_attn.cute.seqlen_info import SeqlenInfoQK + + +@cute.jit +def _get_curr_blocksparse_tensors_varlen( + head_idx: cutlass.Int32, + m_block: cutlass.Int32, + blocksparse_tensors: BlockSparseTensors, + seqlen_info: SeqlenInfoQK, +) -> Tuple[cutlass.Int32, cute.Tensor, cutlass.Int32, Optional[cute.Tensor]]: + """Varlen path: tensors are 2D [nheads, total_m_blocks] / [nheads, total_n_blocks].""" + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors + curr_m_block = seqlen_info.m_block_offset + m_block + curr_block_idx_offset = seqlen_info.block_idx_offset + m_block * seqlen_info.num_n_blocks + curr_mask_block_cnt = mask_block_cnt[head_idx, curr_m_block] + curr_mask_block_idx = cute.domain_offset(curr_block_idx_offset, mask_block_idx[head_idx, None]) + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = full_block_cnt[head_idx, curr_m_block] + curr_full_block_idx = cute.domain_offset( + curr_block_idx_offset, full_block_idx[head_idx, None] + ) + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None + return (curr_mask_block_cnt, curr_mask_block_idx, curr_full_block_cnt, curr_full_block_idx) + + +@cute.jit +def _get_curr_blocksparse_tensors( + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, + m_block: cutlass.Int32, + blocksparse_tensors: BlockSparseTensors, +) -> Tuple[cutlass.Int32, cute.Tensor, cutlass.Int32, Optional[cute.Tensor]]: + """Fixed-length path: tensors are 4D [batch, nheads, m_block, n_block].""" + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None + return (curr_mask_block_cnt, curr_mask_block_idx, curr_full_block_cnt, curr_full_block_idx) + + +@cute.jit +def get_curr_blocksparse_tensors( + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, + m_block: cutlass.Int32, + blocksparse_tensors: BlockSparseTensors, + seqlen_info: SeqlenInfoQK, +) -> Tuple[cutlass.Int32, cute.Tensor, cutlass.Int32, Optional[cute.Tensor]]: + """Extract head, m_block, and batch-local blocksparsity data from blocksparse_tensors""" + if const_expr(len(blocksparse_tensors.mask_block_cnt.shape) == 2): + return _get_curr_blocksparse_tensors_varlen( + head_idx, m_block, blocksparse_tensors, seqlen_info + ) + return _get_curr_blocksparse_tensors(batch_idx, head_idx, m_block, blocksparse_tensors) # NOTE [SM100 block-sparse empty tiles: mbarrier contract] @@ -162,6 +223,7 @@ def produce_block_sparse_loads( batch_idx, head_idx, m_block, + seqlen_info: SeqlenInfoQK, kv_producer_state, load_K, load_V, @@ -187,20 +249,20 @@ def produce_block_sparse_loads( qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and must be converted to unpacked for sparse tensor indexing. """ - - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors - m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] - - if const_expr(full_block_cnt is not None): - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] - else: - curr_full_block_cnt = Int32(0) - curr_full_block_idx = None + ( + curr_mask_block_cnt, + curr_mask_block_idx, + curr_full_block_cnt, + curr_full_block_idx, + ) = get_curr_blocksparse_tensors( + batch_idx, + head_idx, + m_block_sparse, + blocksparse_tensors, + seqlen_info, + ) mask_empty = curr_mask_block_cnt == 0 full_empty = curr_full_block_cnt == 0 @@ -305,7 +367,7 @@ def consume_block_sparse_loads( batch_idx, head_idx, m_block, - seqlen, + seqlen_info, kv_consumer_state, mma_pv_fn, mma_one_n_block, @@ -331,15 +393,20 @@ def consume_block_sparse_loads( qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and must be converted to unpacked for sparse tensor indexing. """ - - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors - m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] + ( + curr_mask_block_cnt, + curr_mask_block_idx, + curr_full_block_cnt, + curr_full_block_idx, + ) = get_curr_blocksparse_tensors( + batch_idx, + head_idx, + m_block_sparse, + blocksparse_tensors, + seqlen_info, + ) processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0 @@ -420,7 +487,7 @@ def consume_block_sparse_loads( mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] kv_consumer_state = process_first_half_block( n_block=mask_n_block, - seqlen=seqlen, + seqlen=seqlen_info, kv_consumer_state=kv_consumer_state, mask_fn=partial( mask_fn, @@ -436,7 +503,7 @@ def consume_block_sparse_loads( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=mask_n_block, - seqlen=seqlen, + seqlen=seqlen_info, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False), ) @@ -447,7 +514,7 @@ def consume_block_sparse_loads( if curr_mask_block_cnt == 0: kv_consumer_state = process_first_half_block( n_block=full_n_block, - seqlen=seqlen, + seqlen=seqlen_info, kv_consumer_state=kv_consumer_state, mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), score_mod_fn=score_mod_fn, @@ -457,7 +524,7 @@ def consume_block_sparse_loads( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=full_n_block, - seqlen=seqlen, + seqlen=seqlen_info, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), ) @@ -467,7 +534,7 @@ def consume_block_sparse_loads( kv_consumer_state = mma_one_n_block( kv_consumer_state, n_block=full_n_block, - seqlen=seqlen, + seqlen=seqlen_info, mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False), ) @@ -531,6 +598,7 @@ def produce_block_sparse_loads_sm100( batch_idx, head_idx, m_block, + seqlen_info, kv_producer_state, load_Q, load_K, @@ -552,17 +620,18 @@ def produce_block_sparse_loads_sm100( """ m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors - - curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] - - if const_expr(full_block_cnt is not None): - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] - else: - curr_full_block_cnt = Int32(0) - curr_full_block_idx = None + ( + curr_mask_block_cnt, + curr_mask_block_idx, + curr_full_block_cnt, + curr_full_block_idx, + ) = get_curr_blocksparse_tensors( + batch_idx, + head_idx, + m_block_sparse, + blocksparse_tensors, + seqlen_info, + ) mask_empty = curr_mask_block_cnt == 0 full_empty = curr_full_block_cnt == 0 @@ -626,17 +695,24 @@ def get_total_block_count( m_block, qhead_per_kvhead: cutlass.Constexpr, q_subtile_factor: cutlass.Constexpr, + seqlen_info: SeqlenInfoQK, ): m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors - if const_expr(full_block_cnt is not None): - return ( - mask_block_cnt[batch_idx, head_idx, m_block_sparse] - + full_block_cnt[batch_idx, head_idx, m_block_sparse] - ) + mask_block_cnt, _, full_block_cnt, *_ = blocksparse_tensors + + if const_expr(len(mask_block_cnt.shape) == 2): + # varlen path: tensors are [num_heads, total_m_block] + curr_m = seqlen_info.m_block_offset + m_block_sparse + total = mask_block_cnt[head_idx, curr_m] + if const_expr(full_block_cnt is not None): + total += full_block_cnt[head_idx, curr_m] else: - return mask_block_cnt[batch_idx, head_idx, m_block_sparse] + # non-varlen: tensors are [batch, num_heads, m_block] + total = mask_block_cnt[batch_idx, head_idx, m_block_sparse] + if const_expr(full_block_cnt is not None): + total += full_block_cnt[batch_idx, head_idx, m_block_sparse] + + return total @cute.jit @@ -649,7 +725,7 @@ def handle_block_sparse_empty_tile_correction_sm100( is_split_kv: cutlass.Constexpr, learnable_sink, mLSE, - seqlen, + seqlen_info, m_block: Int32, head_idx: Int32, batch_idx: Int32, @@ -737,7 +813,7 @@ def handle_block_sparse_empty_tile_correction_sm100( tidx, stage, m_block, - seqlen.seqlen_q, + seqlen_info.seqlen_q, Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs sO[None, None, stage], mO_cur, @@ -763,6 +839,7 @@ def softmax_block_sparse_sm100( batch_idx, head_idx, m_block, + seqlen_info, softmax_step: Callable, mask_fn: Callable, mask_fn_none: Callable, @@ -780,17 +857,18 @@ def softmax_block_sparse_sm100( warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors - - curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] - - if const_expr(full_block_cnt is not None): - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] - else: - curr_full_block_cnt = Int32(0) - curr_full_block_idx = None + ( + curr_mask_block_cnt, + curr_mask_block_idx, + curr_full_block_cnt, + curr_full_block_idx, + ) = get_curr_blocksparse_tensors( + batch_idx, + head_idx, + m_block_sparse, + blocksparse_tensors, + seqlen_info, + ) total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt @@ -854,7 +932,7 @@ def softmax_block_sparse_sm100( full_n_block, is_first=False, mask_fn=partial( - mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary + mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary ), ) for i in cutlass.range(1, curr_full_block_cnt): diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index 4a5726b7493..8d28edae1dc 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -17,17 +17,23 @@ def ceildiv(a: int, b: int) -> int: class BlockSparseTensors(NamedTuple): mask_block_cnt: cute.Tensor mask_block_idx: cute.Tensor - full_block_cnt: cute.Tensor | None - full_block_idx: cute.Tensor | None + full_block_cnt: cute.Tensor | None = None + full_block_idx: cute.Tensor | None = None + cu_total_m_blocks: cute.Tensor | None = None + cu_block_idx_offsets: cute.Tensor | None = None dq_write_order: cute.Tensor | None = None dq_write_order_full: cute.Tensor | None = None def __new_from_mlir_values__(self, values): - if len(values) == 2: - values = (*values, None, None, None, None) - elif len(values) == 4: - values = (*values, None, None) - return BlockSparseTensors(*values) + new_fields = [] + idx = 0 + for original in self: + if original is None: + new_fields.append(None) + else: + new_fields.append(values[idx]) + idx += 1 + return BlockSparseTensors(*new_fields) class BlockSparseTensorsTorch(NamedTuple): @@ -35,6 +41,8 @@ class BlockSparseTensorsTorch(NamedTuple): mask_block_idx: torch.Tensor full_block_cnt: torch.Tensor | None = None full_block_idx: torch.Tensor | None = None + cu_total_m_blocks: torch.Tensor | None = None + cu_block_idx_offsets: torch.Tensor | None = None block_size: tuple[int, int] | None = None dq_write_order: torch.Tensor | None = None dq_write_order_full: torch.Tensor | None = None @@ -214,8 +222,8 @@ def _check_and_expand_block( name: str, cnt: torch.Tensor | None, idx: torch.Tensor | None, - expected_count_shape: Tuple[int, int, int], - expected_index_shape: Tuple[int, int, int, int], + expected_count_shape: Tuple[int, ...], + expected_index_shape: Tuple[int, ...], context: str | None, hint: str | Callable[[], str] | None, ) -> Tuple[torch.Tensor | None, torch.Tensor | None]: @@ -402,8 +410,8 @@ def get_block_sparse_expected_shapes_bwd( def normalize_block_sparse_tensors( tensors: BlockSparseTensorsTorch, *, - expected_count_shape: Tuple[int, int, int], - expected_index_shape: Tuple[int, int, int, int], + expected_count_shape: Tuple[int, ...], + expected_index_shape: Tuple[int, ...], context: str | None = None, hint: str | Callable[[], str] | None = None, ) -> BlockSparseTensorsTorch: @@ -461,6 +469,8 @@ def normalize_block_sparse_tensors( mask_block_idx=mask_idx, full_block_cnt=full_cnt, full_block_idx=full_idx, + cu_total_m_blocks=tensors.cu_total_m_blocks, + cu_block_idx_offsets=tensors.cu_block_idx_offsets, block_size=tensors.block_size, dq_write_order=dq_write_order, dq_write_order_full=dq_write_order_full, @@ -516,6 +526,13 @@ def normalize_block_sparse_config( block_size: tuple[int, int], q_stage: int, ) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None, int]: + """Validate the block-sparse config, infer expected shapes, and normalize. + + Handles both fixed-length (3D `[B, H, M]` / 4D `[B, H, M, N]`) and varlen + (2D `[H, total_m_blocks]` / `[H, total_n_blocks]`) layouts. Varlen is + detected by `tensors.cu_total_m_blocks is not None` and forces + `q_subtile_factor == 1` (TODO: potentially remove this restriction). + """ m_block_size, n_block_size = block_size if tensors.block_size is None: sparse_block_size_q, sparse_block_size_kv = None, n_block_size @@ -525,21 +542,34 @@ def normalize_block_sparse_config( raise ValueError( f"Block sparsity requires sparse_block_size[1]={n_block_size} to match tile_n." ) - expected_count_shape, expected_index_shape, q_subtile_factor = ( - infer_block_sparse_expected_shapes( - tensors, - batch_size=batch_size, - num_head=num_head, - seqlen_q=seqlen_q, - seqlen_k=seqlen_k, - m_block_size=m_block_size, - n_block_size=n_block_size, - q_stage=q_stage, - context="forward", - sparse_block_size_q=sparse_block_size_q, - sparse_block_size_kv=sparse_block_size_kv, + if tensors.cu_total_m_blocks is not None: + base_m_block = q_stage * m_block_size + if sparse_block_size_q is not None and sparse_block_size_q != base_m_block: + raise ValueError( + f"Varlen block sparsity requires sparse_block_size[0]={base_m_block} " + f"(= q_stage * tile_m); got {sparse_block_size_q}." + ) + total_m_blocks = tensors.mask_block_cnt.shape[-1] + total_n_blocks = tensors.mask_block_idx.shape[-1] + expected_count_shape = (num_head, total_m_blocks) + expected_index_shape = (num_head, total_n_blocks) + q_subtile_factor = 1 + else: + expected_count_shape, expected_index_shape, q_subtile_factor = ( + infer_block_sparse_expected_shapes( + tensors, + batch_size=batch_size, + num_head=num_head, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + m_block_size=m_block_size, + n_block_size=n_block_size, + q_stage=q_stage, + context="forward", + sparse_block_size_q=sparse_block_size_q, + sparse_block_size_kv=sparse_block_size_kv, + ) ) - ) normalized_tensors = normalize_block_sparse_tensors( tensors, expected_count_shape=expected_count_shape, @@ -615,6 +645,12 @@ def to_cute_block_sparse_tensors( else None for t in (tensors.full_block_cnt, tensors.full_block_idx) ] + cu_total_m_blocks_tensor, cu_block_idx_offsets_tensor = [ + to_cute_tensor(t, assumed_align=4, leading_dim=0, enable_tvm_ffi=enable_tvm_ffi) + if t is not None + else None + for t in (tensors.cu_total_m_blocks, tensors.cu_block_idx_offsets) + ] dq_write_order_tensor, dq_write_order_full_tensor = [ to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi) if t is not None @@ -627,6 +663,8 @@ def to_cute_block_sparse_tensors( mask_block_idx_tensor, full_block_cnt_tensor, full_block_idx_tensor, + cu_total_m_blocks_tensor, + cu_block_idx_offsets_tensor, dq_write_order_tensor, dq_write_order_full_tensor, ) diff --git a/flash_attn/cute/compute_block_sparsity.py b/flash_attn/cute/compute_block_sparsity.py index 69e8309a028..69ef3c77619 100644 --- a/flash_attn/cute/compute_block_sparsity.py +++ b/flash_attn/cute/compute_block_sparsity.py @@ -11,6 +11,13 @@ BlockSparseTensorsTorch, to_cute_block_sparse_tensors, ) +from flash_attn.cute.block_sparse_utils import get_curr_blocksparse_tensors +from flash_attn.cute.testing import is_fake_mode +from flash_attn.cute.cute_dsl_utils import ( + to_cute_tensor, + get_aux_tensor_metadata, + to_cute_aux_tensor, +) from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar from flash_attn.cute.seqlen_info import SeqlenInfoQK @@ -28,7 +35,6 @@ class BlockSparsityKernel: TODO: - optimize mask_mod evaluation - - varlen support - transposed tensors for bwd pass """ @@ -52,18 +58,31 @@ def __call__( blocksparse_tensors: BlockSparseTensors, seqlen_q: Int32, seqlen_k: Int32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, ): - self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx, *_ = blocksparse_tensors + mask_cnt, mask_idx, full_cnt, full_idx, mCuTotalMBlocks, mCuBlockIdxOffsets, *_ = ( + blocksparse_tensors + ) + + self.is_varlen_q = const_expr(mCuSeqlensQ is not None) if const_expr(self.compute_full_blocks): - assert self.full_cnt is not None and self.full_idx is not None, ( + assert full_cnt is not None and full_idx is not None, ( "full block tensors must be provided when computing full blocks" ) - - batch_size, num_heads, num_m_blocks, num_n_blocks = self.mask_idx.shape - # launch 1 CTA per m block - grid = [num_m_blocks, num_heads, batch_size] + if const_expr(not self.is_varlen_q): + batch_size, num_heads, num_m_blocks, _ = mask_idx.shape + total_m_blocks = batch_size * num_m_blocks + else: + assert const_expr(mCuTotalMBlocks is not None), ( + "mCuTotalMBlocks must be provided when varlen q" + ) + num_heads, total_m_blocks = mask_cnt.shape # num_m_blocks is total_m_blocks + batch_size = mCuSeqlensQ.shape[0] - 1 if const_expr(self.use_fast_sampling): num_threads = 5 @@ -72,46 +91,46 @@ def __call__( num_threads = self.tile_mn[0] self.num_warps = (num_threads + 32 - 1) // 32 + if const_expr(not self.is_varlen_q): + grid = [num_m_blocks, num_heads, batch_size] + else: + grid = [total_m_blocks, num_heads, 1] + self.kernel( - self.mask_cnt, - self.mask_idx, - self.full_cnt, - self.full_idx, - num_n_blocks, + blocksparse_tensors, seqlen_q, seqlen_k, + batch_size, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + mCuTotalMBlocks, + mCuBlockIdxOffsets, aux_tensors, ).launch(grid=grid, block=[num_threads, 1, 1]) @cute.kernel def kernel( self, - mask_cnt: cute.Tensor, - mask_idx: cute.Tensor, - full_cnt: cute.Tensor, - full_idx: cute.Tensor, - num_n_blocks: Int32, + blocksparse_tensors: BlockSparseTensors, seqlen_q: Int32, seqlen_k: Int32, + batch_size: Int32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, + mCuBlockIdxOffsets: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, ): tidx, _, _ = cute.arch.thread_idx() warp_idx = cute.arch.warp_idx() lane_id = cute.arch.lane_idx() - m_block, head_idx, batch_idx = cute.arch.block_idx() ssa = partial(scalar_to_ssa, dtype=Int32) - seqlen = SeqlenInfoQK.create( - batch_idx, - seqlen_q, - seqlen_k, - mCuSeqlensQ=None, - mCuSeqlensK=None, - mSeqUsedQ=None, - mSeqUsedK=None, - ) - @cute.struct class SharedStorage: reduction_buffer_smem: cute.struct.Align[ @@ -124,132 +143,163 @@ class SharedStorage: reduction_buffer = storage.reduction_buffer_smem.get_tensor( cute.make_layout((self.num_warps, 2)) ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=seqlen_q, + seqlen_k_static=seqlen_k, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + mCuTotalMBlocks=mCuTotalMBlocks, + mCuBlockIdxOffsets=mCuBlockIdxOffsets, + tile_m=self.tile_mn[0], + tile_n=self.tile_mn[1], + ) + + if const_expr(not self.is_varlen_q): + m_block, head_idx, batch_idx = cute.arch.block_idx() + else: + global_m_block, head_idx, _ = cute.arch.block_idx() + # Binary search over cu_total_m_blocks to find batch_idx + lo = Int32(0) + hi = batch_size + while lo < hi: + mid = (lo + hi) // 2 + if mCuTotalMBlocks[mid + 1] <= global_m_block: + lo = mid + 1 + else: + hi = mid + batch_idx = lo + m_block = global_m_block - mCuTotalMBlocks[batch_idx] + + seqlen = SeqlenInfoCls(batch_idx) + seqlen_q = seqlen.seqlen_q + seqlen_k = seqlen.seqlen_k + global_m_block = seqlen.m_block_offset + m_block + + num_n_blocks = (seqlen_k + self.tile_mn[1] - 1) // self.tile_mn[1] + + _, curr_mask_idx, _, curr_full_idx = get_curr_blocksparse_tensors( + batch_idx, head_idx, m_block, blocksparse_tensors, seqlen + ) num_mask_blocks = Int32(0) num_full_blocks = Int32(0) - for n_block in cutlass.range(num_n_blocks, unroll_full=True): - m_base = m_block * self.tile_mn[0] + m_base = m_block * self.tile_mn[0] + if const_expr(self.use_fast_sampling): + # Loop-invariant per-thread q_idx for the 5 sample points + # (tidx 0, 1: top corners; 2, 3: bottom corners; 4: center). + q_idx_sample = m_base + if tidx == 2 or tidx == 3: + q_idx_sample = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1) + elif tidx == 4: + q_idx_sample = m_base + cutlass.min(seqlen_q - m_base, self.tile_mn[0]) // 2 + else: + q_idx_thread = m_base + tidx + thread_in_bounds = Boolean(tidx < self.tile_mn[0] and q_idx_thread < seqlen_q) + + for n_block in cutlass.range(num_n_blocks): n_base = n_block * self.tile_mn[1] if const_expr(self.use_fast_sampling): - # Fast path: 5-point sampling (4 corners + center) - # Clamps OOB indices to nearest in bounds. - thread_result = Boolean(False) - thread_is_valid = Boolean(False) - q_idx = Int32(0) - kv_idx = Int32(0) + # 5-point sampling (4 corners + center). Interior n_blocks + # (n_base + tile_n <= seqlen_k) skip the OOB clamp on the right / + # center samples. + is_interior = (n_base + self.tile_mn[1]) <= seqlen_k + n_right = Int32(0) + n_mid = Int32(0) + if is_interior: + n_right = n_base + self.tile_mn[1] - 1 + n_mid = n_base + self.tile_mn[1] // 2 + else: + n_right = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1) + n_mid = n_base + cutlass.min(seqlen_k - n_base, self.tile_mn[1]) // 2 - if tidx == 0: - # Top-left corner (0, 0); always in bounds - q_idx = m_base - kv_idx = n_base - elif tidx == 1: - # Top-right corner - q_idx = m_base - kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1) - elif tidx == 2: - # Bottom-left corner - q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1) - kv_idx = n_base - elif tidx == 3: - # Bottom-right corner - q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1) - kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1) + kv_idx = n_base + if tidx == 1 or tidx == 3: + kv_idx = n_right elif tidx == 4: - # Center point - q_idx = m_base + (cutlass.min(seqlen_q - m_base, self.tile_mn[0])) // 2 - kv_idx = n_base + (cutlass.min(seqlen_k - n_base, self.tile_mn[1])) // 2 - else: - thread_is_valid = Boolean(False) + kv_idx = n_mid - # Check bounds and determine if this thread has a valid index pair - if tidx < 5 and q_idx < seqlen_q and kv_idx < seqlen_k: + thread_result = Boolean(False) + thread_is_valid = Boolean(False) + if tidx < 5: thread_is_valid = Boolean(True) - q_idx_ssa = ssa(q_idx) - kv_idx_ssa = ssa(kv_idx) thread_result = ssa_to_scalar( self.mask_mod( ssa(batch_idx), ssa(head_idx), - q_idx_ssa, - kv_idx_ssa, + ssa(q_idx_sample), + ssa(kv_idx), seqlen, aux_tensors, ) ) - else: - thread_is_valid = Boolean(False) - # Use vote_any_sync to see if any valid thread found unmasked or masked - # Only count results from threads that checked valid indices has_unmasked = cute.arch.vote_any_sync(thread_result & thread_is_valid) - has_masked = cute.arch.vote_any_sync((Boolean(not thread_result)) & thread_is_valid) + has_masked = cute.arch.vote_any_sync(Boolean(not thread_result) & thread_is_valid) else: - # Full path: check all elements in the block - # Track if this thread's row has any masked or unmasked elements + # Full path. Interior blocks (n_base + tile_n <= seqlen_k) drop the + # per-element bound check; the boundary block (at most one) keeps it. thread_has_unmasked = Boolean(False) thread_has_masked = Boolean(False) - thread_is_valid = Boolean(False) - - # Each thread handles 1 row - q_idx = m_base + tidx kv_idx = Int32(0) - if tidx < self.tile_mn[0] and q_idx < seqlen_q: - thread_is_valid = Boolean(True) - q_idx_ssa = ssa(q_idx) - - # Loop over all columns in this row - for c in cutlass.range(self.tile_mn[1], unroll_full=True): - kv_idx = n_base + c - kv_idx_ssa = ssa(kv_idx) + is_interior = (n_base + self.tile_mn[1]) <= seqlen_k - # Only check elements within valid sequence bounds - if kv_idx < seqlen_k: - # Direct scalar call + if is_interior: + if thread_in_bounds: + for c in cutlass.range(self.tile_mn[1], unroll_full=True): mask_val = ssa_to_scalar( self.mask_mod( ssa(batch_idx), ssa(head_idx), - q_idx_ssa, - kv_idx_ssa, + ssa(q_idx_thread), + ssa(n_base + c), seqlen, aux_tensors, ) ) + thread_has_unmasked |= Boolean(mask_val) + thread_has_masked |= Boolean(not mask_val) + else: + if thread_in_bounds: + for c in cutlass.range(self.tile_mn[1], unroll_full=True): + kv_idx = n_base + c + if kv_idx < seqlen_k: + mask_val = ssa_to_scalar( + self.mask_mod( + ssa(batch_idx), + ssa(head_idx), + ssa(q_idx_thread), + ssa(kv_idx), + seqlen, + aux_tensors, + ) + ) + thread_has_unmasked |= Boolean(mask_val) + thread_has_masked |= Boolean(not mask_val) - # Update tracking flags - if mask_val: - thread_has_unmasked = Boolean(True) - else: - thread_has_masked = Boolean(True) - - # Block-level reduction to combine results across all threads - # Only count votes from threads that checked valid indices - warp_has_unmasked_mask = cute.arch.vote_any_sync( - thread_has_unmasked & thread_is_valid - ) - warp_has_masked_mask = cute.arch.vote_any_sync(thread_has_masked & thread_is_valid) - - # lane 0 writes the ballot mask to shared memory - lane_id = tidx % 32 + warp_unmasked = cute.arch.vote_any_sync(thread_has_unmasked & thread_in_bounds) + warp_masked = cute.arch.vote_any_sync(thread_has_masked & thread_in_bounds) if lane_id == 0: - # Store as Int8 - reduction_buffer[warp_idx, 0] = Int8(1) if warp_has_unmasked_mask else Int8(0) - reduction_buffer[warp_idx, 1] = Int8(1) if warp_has_masked_mask else Int8(0) - + reduction_buffer[warp_idx, 0] = Int8(1) if warp_unmasked else Int8(0) + reduction_buffer[warp_idx, 1] = Int8(1) if warp_masked else Int8(0) cute.arch.sync_threads() - # Thread 0 ORs all warp results together + # Cross-warp OR via warp 0; thread 0 (lane 0 of warp 0) holds the result. has_unmasked = Boolean(False) has_masked = Boolean(False) - if tidx == 0: - for w in cutlass.range(self.num_warps): - if reduction_buffer[w, 0]: - has_unmasked = Boolean(True) - if reduction_buffer[w, 1]: - has_masked = Boolean(True) + if warp_idx == 0: + lane_unmasked = Boolean(False) + lane_masked = Boolean(False) + if lane_id < self.num_warps: + lane_unmasked = reduction_buffer[lane_id, 0] != Int8(0) + lane_masked = reduction_buffer[lane_id, 1] != Int8(0) + has_unmasked = cute.arch.vote_any_sync(lane_unmasked) + has_masked = cute.arch.vote_any_sync(lane_masked) # Only thread 0 updates the output arrays (common to both paths) if tidx == 0: @@ -261,17 +311,23 @@ class SharedStorage: is_full = Boolean(has_unmasked and (not has_masked)) if is_partial: - mask_idx[batch_idx, head_idx, m_block, num_mask_blocks] = n_block + curr_mask_idx[num_mask_blocks] = n_block num_mask_blocks += 1 elif is_full and const_expr(self.compute_full_blocks): - full_idx[batch_idx, head_idx, m_block, num_full_blocks] = n_block + curr_full_idx[num_full_blocks] = n_block num_full_blocks += 1 # Only thread 0 writes back the counts if tidx == 0: - mask_cnt[batch_idx, head_idx, m_block] = num_mask_blocks - if const_expr(self.compute_full_blocks): - full_cnt[batch_idx, head_idx, m_block] = num_full_blocks + mask_cnt, _, full_cnt, *_ = blocksparse_tensors + if const_expr(self.is_varlen_q): + mask_cnt[head_idx, global_m_block] = num_mask_blocks + if const_expr(self.compute_full_blocks): + full_cnt[head_idx, global_m_block] = num_full_blocks + else: + mask_cnt[batch_idx, head_idx, m_block] = num_mask_blocks + if const_expr(self.compute_full_blocks): + full_cnt[batch_idx, head_idx, m_block] = num_full_blocks def compute_block_sparsity( @@ -282,11 +338,17 @@ def compute_block_sparsity( seqlen_q, seqlen_k, mask_mod: Callable, - aux_tensors: Optional[list], # list[cute.Tensor] + aux_tensors: Optional[list], device, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + cu_total_m_blocks: Optional[torch.Tensor] = None, + cu_block_idx_offsets: Optional[torch.Tensor] = None, compute_full_blocks: bool = True, use_fast_sampling: bool = False, -) -> Tuple[BlockSparseTensors, BlockSparseTensorsTorch]: +) -> BlockSparseTensorsTorch: """ Computes block sparsity for a given `mask_mod`. @@ -300,59 +362,141 @@ def compute_block_sparsity( mask_mod: The `mask_mod` callable to use. aux_tensors: A list of auxiliary tensors. device: The device to use. + cu_seqlens_q: Cumulative q sequence lengths for varlen + cu_seqlens_k: Cumulative k sequence lengths for varlen + seqused_q: Per-batch effective q sequence lengths + seqused_k: Per-batch effective k sequence lengths + cu_total_m_blocks: Cumulative total m blocks tensor for varlen q + cu_block_idx_offsets: Cumulative offsets into the packed mask_block_idx / + full_block_idx tensors per batch (== cumsum of M_b * N_b). compute_full_blocks: Whether to compute full blocks. If False, only partially-masked blocks are computed. use_fast_sampling: Whether to use 5-point sampling (4 corners + center). This is much faster, but only suitable for masks where this check is sufficient. Returns: - A tuple of `BlockSparseTensors` and `BlockSparseTensorsTorch`. + BlockSparseTensorsTorch """ - # Check if mask_mod is marked as suitable for 5-point fast sampling + # Check if mask_mod is marked as suitable for 5-point sampling use_fast_sampling = getattr(mask_mod, "use_fast_sampling", use_fast_sampling) num_m_blocks = (seqlen_q + tile_m - 1) // tile_m num_n_blocks = (seqlen_k + tile_n - 1) // tile_n - mask_block_cnt = torch.zeros( - (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32 - ) - mask_block_idx = torch.zeros( - (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32 - ) - full_block_cnt = ( - torch.zeros((batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32) - if compute_full_blocks - else None - ) - full_block_idx = ( - torch.zeros( + if cu_seqlens_q is not None: + assert cu_total_m_blocks is not None, "total m blocks must be provided when varlen q" + total_m_blocks = cu_total_m_blocks[-1].item() + if cu_block_idx_offsets is None and (cu_seqlens_k is not None or seqused_k is not None): + # Derive cu_block_idx_offsets from per-batch K seqlens. + cu_block_idx_offsets_list = [0] + for batch_idx in range(batch_size): + batch_seqlen_q = cu_seqlens_q[batch_idx + 1].item() - cu_seqlens_q[batch_idx].item() + if cu_seqlens_k is not None: + batch_seqlen_k = ( + cu_seqlens_k[batch_idx + 1].item() - cu_seqlens_k[batch_idx].item() + ) + else: + batch_seqlen_k = seqused_k[batch_idx].item() + num_m_blocks_batch = (batch_seqlen_q + tile_m - 1) // tile_m + num_n_blocks_batch = (batch_seqlen_k + tile_n - 1) // tile_n + cu_block_idx_offsets_list.append( + cu_block_idx_offsets_list[-1] + num_m_blocks_batch * num_n_blocks_batch + ) + cu_block_idx_offsets = torch.tensor( + cu_block_idx_offsets_list, dtype=torch.int32, device=device + ) + if cu_block_idx_offsets is not None: + total_n_blocks = cu_block_idx_offsets[-1].item() + else: + # Uniform-K varlen-Q: every batch has the same K seqlen. + total_n_blocks = total_m_blocks * num_n_blocks + + mask_block_cnt = torch.zeros((num_heads, total_m_blocks), device=device, dtype=torch.int32) + mask_block_idx = torch.zeros((num_heads, total_n_blocks), device=device, dtype=torch.int32) + full_block_cnt = ( + torch.zeros((num_heads, total_m_blocks), device=device, dtype=torch.int32) + if compute_full_blocks + else None + ) + full_block_idx = ( + torch.zeros((num_heads, total_n_blocks), device=device, dtype=torch.int32) + if compute_full_blocks + else None + ) + else: + total_m_blocks = batch_size * num_m_blocks + total_n_blocks = batch_size * num_m_blocks * num_n_blocks + + mask_block_cnt = torch.zeros( + (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32 + ) + mask_block_idx = torch.zeros( (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32 ) - if compute_full_blocks - else None - ) + full_block_cnt = ( + torch.zeros((batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32) + if compute_full_blocks + else None + ) + full_block_idx = ( + torch.zeros( + (batch_size, num_heads, num_m_blocks, num_n_blocks), + device=device, + dtype=torch.int32, + ) + if compute_full_blocks + else None + ) blocksparse_tensors_torch = BlockSparseTensorsTorch( mask_block_cnt=mask_block_cnt, mask_block_idx=mask_block_idx, full_block_cnt=full_block_cnt, full_block_idx=full_block_idx, + cu_total_m_blocks=cu_total_m_blocks, + cu_block_idx_offsets=cu_block_idx_offsets, block_size=(tile_m, tile_n), ) mask_mod_hash = hash_callable(mask_mod) - blocksparse_tensors = to_cute_block_sparse_tensors( - blocksparse_tensors_torch, enable_tvm_ffi=True - ) + if aux_tensors is not None: + aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors) + else: + aux_tensor_metadata = None compile_key = ( tile_m, tile_n, mask_mod_hash, + aux_tensor_metadata, compute_full_blocks, + cu_seqlens_q is None, + cu_seqlens_k is None, + seqused_q is None, + seqused_k is None, aux_tensors is not None, use_fast_sampling, ) if compile_key not in compute_block_sparsity.compile_cache: + ( + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + ) = [ + to_cute_tensor(t, assumed_align=4, leading_dim=0) if t is not None else None + for t in ( + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + ) + ] + blocksparse_tensors = to_cute_block_sparse_tensors( + blocksparse_tensors_torch, enable_tvm_ffi=True + ) + if aux_tensors is not None: + cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] + else: + cute_aux_tensors = None kernel = BlockSparsityKernel( mask_mod, tile_mn=(tile_m, tile_n), @@ -362,24 +506,40 @@ def compute_block_sparsity( ) compute_block_sparsity.compile_cache[compile_key] = cute.compile( - kernel, blocksparse_tensors, seqlen_q, seqlen_k, aux_tensors, options="--enable-tvm-ffi" + kernel, + blocksparse_tensors, + seqlen_q, + seqlen_k, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + cute_aux_tensors, + options="--enable-tvm-ffi", ) - compute_block_sparsity.compile_cache[compile_key]( - ( - blocksparse_tensors_torch.mask_block_cnt, - blocksparse_tensors_torch.mask_block_idx, - blocksparse_tensors_torch.full_block_cnt, - blocksparse_tensors_torch.full_block_idx, - blocksparse_tensors_torch.dq_write_order, - blocksparse_tensors_torch.dq_write_order_full, - ), - seqlen_q, - seqlen_k, - aux_tensors, - ) + if not is_fake_mode(): + compute_block_sparsity.compile_cache[compile_key]( + ( + blocksparse_tensors_torch.mask_block_cnt, + blocksparse_tensors_torch.mask_block_idx, + blocksparse_tensors_torch.full_block_cnt, + blocksparse_tensors_torch.full_block_idx, + blocksparse_tensors_torch.cu_total_m_blocks, + blocksparse_tensors_torch.cu_block_idx_offsets, + blocksparse_tensors_torch.dq_write_order, + blocksparse_tensors_torch.dq_write_order_full, + ), + seqlen_q, + seqlen_k, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + aux_tensors, + ) - return blocksparse_tensors, blocksparse_tensors_torch + return blocksparse_tensors_torch compute_block_sparsity.compile_cache = {} diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 42acbeaec86..4d38174c2c8 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -722,6 +722,10 @@ class SharedStorage: self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) if cutlass.const_expr(self.use_block_sparsity and mPageTable is not None): raise NotImplementedError("Block sparsity + paged KV not supported on SM100") + if cutlass.const_expr(self.use_block_sparsity and self.is_varlen_q): + assert const_expr(blocksparse_tensors.cu_total_m_blocks is not None), ( + "blocksparse_tensors.cu_total_m_blocks must be provided for varlen blocksparsity" + ) # Launch the kernel synchronously self.kernel( @@ -1049,6 +1053,12 @@ def kernel( mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, + mCuTotalMBlocks=( + blocksparse_tensors.cu_total_m_blocks if blocksparse_tensors is not None else None + ), + mCuBlockIdxOffsets=( + blocksparse_tensors.cu_block_idx_offsets if blocksparse_tensors is not None else None + ), ) AttentionMaskCls = partial( AttentionMask, @@ -1488,6 +1498,7 @@ def load( batch_idx, head_idx, m_block, + seqlen, kv_producer_state, load_Q, load_K, @@ -1637,6 +1648,7 @@ def mma( m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, self.q_subtile_factor if self.q_subtile_factor is not None else 1, + seqlen_info=seqlen, ) process_tile = block_iter_count > Int32(0) else: @@ -2003,6 +2015,7 @@ def softmax_loop( m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, self.q_subtile_factor if self.q_subtile_factor is not None else 1, + seqlen_info=seqlen, ) has_work = tile_block_count > Int32(0) else: @@ -2058,6 +2071,7 @@ def softmax_loop( batch_idx, head_idx, m_block, + seqlen, softmax_step, mask_fn, mask_fn_none, @@ -2415,6 +2429,7 @@ def correction_loop( m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, self.q_subtile_factor if self.q_subtile_factor is not None else 1, + seqlen_info=seqlen, ) has_work = total_block_count > Int32(0) else: diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py index 4108ce451ff..bbbe7a063d9 100644 --- a/flash_attn/cute/flash_fwd_sm90.py +++ b/flash_attn/cute/flash_fwd_sm90.py @@ -553,6 +553,12 @@ def kernel( mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, + mCuTotalMBlocks=( + blocksparse_tensors.cu_total_m_blocks if blocksparse_tensors is not None else None + ), + mCuBlockIdxOffsets=( + blocksparse_tensors.cu_block_idx_offsets if blocksparse_tensors is not None else None + ), # Don't need to pass in tile_mn because we won't access offset_padded ) AttentionMaskCls = partial( diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index bcfc406b023..0bd4651190c 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -55,6 +55,7 @@ to_cute_block_sparse_tensors, normalize_block_sparse_config, normalize_block_sparse_config_bwd, + get_block_sparse_broadcast_pattern, ) def _parse_arch_str(arch_str): @@ -598,32 +599,25 @@ def _flash_attn_fwd( is_dense_noncausal = not is_varlen and not causal and not local use_clc_scheduler = requested_use_clc_scheduler and not is_varlen_mha and not is_dense_noncausal - if mask_mod is not None: - if is_varlen: - raise NotImplementedError( - "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR." - ) - if use_block_sparsity: - if is_varlen: - raise NotImplementedError( - "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR." - ) # NB: pack_gqa requires block sparse head dim == 1 (broadcasted) - if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1: + head_dim_idx = 0 if block_sparse_tensors.mask_block_cnt.ndim == 2 else 1 + if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[head_dim_idx] != 1: pack_gqa = False if is_split_kv: raise NotImplementedError( "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split." ) + if cu_seqlens_q is not None: + assert block_sparse_tensors.cu_total_m_blocks is not None, ( + "Varlen block sparsity requires block_sparse_tensors.cu_total_m_blocks." + ) # See get_broadcast_dims for why this is needed in compile key block_sparse_broadcast_pattern = None normalized_block_sparse_tensors = None q_subtile_factor = None if block_sparse_tensors is not None: - if seqlen_q is None: - raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).") ( normalized_block_sparse_tensors, block_sparse_broadcast_pattern, @@ -703,6 +697,8 @@ def _flash_attn_fwd( q_descale is not None, k_descale is not None, v_descale is not None, + block_sparse_tensors is None or block_sparse_tensors.cu_total_m_blocks is None, + block_sparse_tensors is None or block_sparse_tensors.cu_block_idx_offsets is None, tile_m, tile_n, q_stage, @@ -973,13 +969,17 @@ def _flash_attn_fwd( window_size_left, window_size_right, learnable_sink_tensor, - sparse_tensors, - cute_aux_tensors, - current_stream, ] if arch // 10 in [10, 11]: - compile_args.insert(-3, descale_tensors_tensor) - _flash_attn_fwd.compile_cache[compile_key] = cute.compile(*compile_args, options="--enable-tvm-ffi") + compile_args.append(descale_tensors_tensor) + compile_args.extend([ + sparse_tensors, + cute_aux_tensors, + ]) + compile_args.append(current_stream) + _flash_attn_fwd.compile_cache[compile_key] = cute.compile( + *compile_args, options="--enable-tvm-ffi" + ) if not is_fake_mode(): q_call, k_call, v_call = q.detach(), k.detach(), v.detach() @@ -1039,6 +1039,8 @@ def _flash_attn_fwd( normalized_block_sparse_tensors.mask_block_idx, normalized_block_sparse_tensors.full_block_cnt, normalized_block_sparse_tensors.full_block_idx, + normalized_block_sparse_tensors.cu_total_m_blocks, + normalized_block_sparse_tensors.cu_block_idx_offsets, normalized_block_sparse_tensors.dq_write_order, normalized_block_sparse_tensors.dq_write_order_full, ) @@ -1861,6 +1863,8 @@ def _flash_attn_bwd( normalized_block_sparse_tensors.mask_block_idx, normalized_block_sparse_tensors.full_block_cnt, normalized_block_sparse_tensors.full_block_idx, + normalized_block_sparse_tensors.cu_total_m_blocks, + normalized_block_sparse_tensors.cu_block_idx_offsets, normalized_block_sparse_tensors.dq_write_order, normalized_block_sparse_tensors.dq_write_order_full, ) @@ -2027,6 +2031,8 @@ def forward( deterministic: bool = False, score_mod: Optional[Callable] = None, score_mod_bwd: Optional[Callable] = None, + mask_mod: Optional[Callable] = None, + block_sparse_tensors: Optional[list] = None, aux_tensors: Optional[list] = None, return_lse: bool = False, ): @@ -2052,6 +2058,8 @@ def forward( num_splits=num_splits, pack_gqa=pack_gqa, score_mod=score_mod, + mask_mod=mask_mod, + block_sparse_tensors=block_sparse_tensors, aux_tensors=aux_tensors, return_lse=return_lse, gather_kv_indices=gather_kv_indices, @@ -2187,6 +2195,8 @@ def flash_attn_varlen_func( deterministic: bool = False, score_mod: Optional[Callable] = None, score_mod_bwd: Optional[Callable] = None, + mask_mod: Optional[Callable] = None, + block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, aux_tensors: Optional[list] = None, return_lse: bool = False, ): @@ -2228,6 +2238,8 @@ def flash_attn_varlen_func( deterministic, score_mod, score_mod_bwd, + mask_mod, + block_sparse_tensors, aux_tensors, return_lse, ) diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py index 8e8fdf69ddc..c8ba5672664 100644 --- a/flash_attn/cute/seqlen_info.py +++ b/flash_attn/cute/seqlen_info.py @@ -71,6 +71,9 @@ class SeqlenInfoQK: padded_offset_k: Int32 seqlen_q: Int32 seqlen_k: Int32 + m_block_offset: Int32 + block_idx_offset: Int32 + num_n_blocks: Int32 has_cu_seqlens_q: cutlass.Constexpr[bool] has_cu_seqlens_k: cutlass.Constexpr[bool] has_seqused_q: cutlass.Constexpr[bool] @@ -85,6 +88,8 @@ def create( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, + mCuBlockIdxOffsets: Optional[cute.Tensor] = None, tile_m: cutlass.Constexpr[Int32] = 128, tile_n: cutlass.Constexpr[Int32] = 128, ): @@ -116,6 +121,13 @@ def create( if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx + 1] - offset_k ) + m_block_offset = 0 if const_expr(mCuTotalMBlocks is None) else mCuTotalMBlocks[batch_idx] + num_n_blocks = (seqlen_k + tile_n - 1) // tile_n + block_idx_offset = ( + mCuBlockIdxOffsets[batch_idx] + if const_expr(mCuBlockIdxOffsets is not None) + else m_block_offset * num_n_blocks + ) return SeqlenInfoQK( offset_q, offset_k, @@ -123,6 +135,9 @@ def create( padded_offset_k, seqlen_q, seqlen_k, + m_block_offset, + block_idx_offset, + num_n_blocks, has_cu_seqlens_q=mCuSeqlensQ is not None, has_cu_seqlens_k=mCuSeqlensK is not None, has_seqused_q=mSeqUsedQ is not None, diff --git a/tests/cute/benchmark_block_sparsity.py b/tests/cute/benchmark_block_sparsity.py index ed6bfad2daa..c7f3229d464 100644 --- a/tests/cute/benchmark_block_sparsity.py +++ b/tests/cute/benchmark_block_sparsity.py @@ -39,7 +39,7 @@ class BenchmarkConfig: mask_name: str tile_m: int = 128 tile_n: int = 128 - use_fast_sampling: bool = False + use_fast_sampling: bool = True aux_tensors_cute: Optional[list] = None @@ -162,7 +162,7 @@ def benchmark_cute_block_sparsity( blocksparse_tensors, config.seqlen_q, config.seqlen_k, - config.aux_tensors_cute, + aux_tensors=config.aux_tensors_cute, ) def generate_tensors(): @@ -336,7 +336,7 @@ def main(): mask_name=config.mask_name, tile_m=config.tile_m, tile_n=config.tile_n, - use_fast_sampling=False, + use_fast_sampling=True, aux_tensors_cute=[doc_ids_cute], ) ) diff --git a/tests/cute/mask_mod_definitions.py b/tests/cute/mask_mod_definitions.py index 0820c6f5271..38514f85e19 100644 --- a/tests/cute/mask_mod_definitions.py +++ b/tests/cute/mask_mod_definitions.py @@ -159,7 +159,6 @@ def cute_document_mask( return m_doc == n_doc -@fast_sampling @cute.jit def cute_ima_mask( batch: cute.TensorSSA, @@ -180,7 +179,99 @@ def cute_ima_mask( # n_idx_global = n_idx + seqlen_info.offset_k # ============================================================================= -# TODO: Add varlen mask implementations here + +@fast_sampling +@cute.jit +def cute_global_packed_doc_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + """Document mask using globally-indexed packed 1D doc ID tensors. + + aux_tensors[0]: doc_ids_q (total_q,) int32 — packed doc IDs for Q tokens + aux_tensors[1]: doc_ids_k (total_k,) int32 — packed doc IDs for K tokens + Mask: doc_ids_q[m_global] == doc_ids_k[n_global] + """ + doc_ids_q = aux_tensors[0] + doc_ids_k = aux_tensors[1] + + offset_q = seqlen_info.offset_q + m_global = m_idx + offset_q + m_frag = cute.make_fragment(1, cutlass.Int32) + m_frag.store(m_global) + m_doc_frag = cute.make_fragment(1, cutlass.Int32) + m_doc_frag[0] = doc_ids_q[m_frag[0]] + + offset_k = seqlen_info.offset_k + n_global = n_idx + offset_k + n_frag = cute.make_fragment(1, cutlass.Int32) + n_frag.store(n_global) + n_doc_frag = cute.make_fragment(1, cutlass.Int32) + n_doc_frag[0] = doc_ids_k[n_frag[0]] + + m_doc = m_doc_frag.load() + n_doc = n_doc_frag.load() + return m_doc == n_doc + + +@fast_sampling +@cute.jit +def cute_global_ima_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + """IMA-style mask using globally-indexed threshold tensor. + + aux_tensors[0]: thresholds (total_k,) int32 — per-global-kv-position threshold + Mask: n_idx >= thresholds[n_global] (local n_idx >= globally-indexed threshold) + """ + thresholds = aux_tensors[0] + + offset_k = seqlen_info.offset_k + n_global = n_idx + offset_k + n_frag = cute.make_fragment(1, cutlass.Int32) + n_frag.store(n_global) + val_frag = cute.make_fragment(1, cutlass.Int32) + val_frag[0] = thresholds[n_frag[0]] + threshold = val_frag.load() + + return n_idx >= threshold + + +@fast_sampling +@cute.jit +def cute_global_causal_window_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + """Causal window mask with per-token window sizes indexed globally. + + aux_tensors[0]: windows (total_q,) int32 — per-global-q-position window size + Mask: (n_idx <= m_idx) & (m_idx - n_idx <= windows[m_global]) + """ + windows = aux_tensors[0] + + offset_q = seqlen_info.offset_q + m_global = m_idx + offset_q + m_frag = cute.make_fragment(1, cutlass.Int32) + m_frag.store(m_global) + win_frag = cute.make_fragment(1, cutlass.Int32) + win_frag[0] = windows[m_frag[0]] + window = win_frag.load() + + return (n_idx <= m_idx) & ((m_idx - n_idx) <= window) # ============================================================================= @@ -246,6 +337,61 @@ def flex_ima_mask(b, h, q_idx, kv_idx, bias): return kv_idx >= bias[kv_idx] +# ============================================================================= +# Flex reference factories for global-index masks (per-sequence) +# Each factory(seq_idx, sq, sk) -> mask(b, h, q_idx, kv_idx) +# where q_idx/kv_idx are local (0-indexed) within the sequence. +# ============================================================================= + + +def global_packed_doc_flex_factory(doc_ids_q, doc_ids_k, cu_seqlens_q, cu_seqlens_k): + """Factory for per-sequence flex reference of cute_global_packed_doc_mask.""" + + def factory(seq_idx, sq, sk): + q_offset = cu_seqlens_q[seq_idx].item() + k_offset = cu_seqlens_k[seq_idx].item() + + def mask(b, h, q_idx, kv_idx): + q_global = q_offset + q_idx + k_global = k_offset + kv_idx + return doc_ids_q[q_global] == doc_ids_k[k_global] + + return mask + + return factory + + +def global_ima_flex_factory(thresholds, cu_seqlens_k): + """Factory for per-sequence flex reference of cute_global_ima_mask.""" + + def factory(seq_idx, sq, sk): + k_offset = cu_seqlens_k[seq_idx].item() + + def mask(b, h, q_idx, kv_idx): + k_global = k_offset + kv_idx + return kv_idx >= thresholds[k_global] + + return mask + + return factory + + +def global_causal_window_flex_factory(windows, cu_seqlens_q): + """Factory for per-sequence flex reference of cute_global_causal_window_mask.""" + + def factory(seq_idx, sq, sk): + q_offset = cu_seqlens_q[seq_idx].item() + + def mask(b, h, q_idx, kv_idx): + q_global = q_offset + q_idx + window = windows[q_global] + return (kv_idx <= q_idx) & ((q_idx - kv_idx) <= window) + + return mask + + return factory + + # ============================================================================= # Utility functions # ============================================================================= @@ -253,7 +399,9 @@ def flex_ima_mask(b, h, q_idx, kv_idx, bias): def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): """Generate synthetic document ids shared across heads.""" - doc_ids_tensor = torch.zeros(batch, nheads, seqlen_q, dtype=torch.int32, device=device) + doc_ids_tensor = torch.zeros( + batch, nheads, seqlen_q, dtype=torch.int32, device=device + ) for b in range(batch): N = seqlen_q max_segments = max(1, math.ceil(math.sqrt(max(N // 4, 1)))) @@ -271,6 +419,83 @@ def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): return doc_ids_tensor +def make_packed_doc_ids(seqlens_q, seqlens_k, device="cuda"): + """Generate packed 1D doc ID tensors for Q and K for varlen global-index testing. + + For each sequence, divides tokens into sqrt(len)-ish segments. + Returns (doc_ids_q, doc_ids_k) of shape (total_q,) and (total_k,). + """ + total_q = sum(seqlens_q) + total_k = sum(seqlens_k) + doc_ids_q = torch.zeros(total_q, dtype=torch.int32, device=device) + doc_ids_k = torch.zeros(total_k, dtype=torch.int32, device=device) + + q_off = 0 + k_off = 0 + for sq, sk in zip(seqlens_q, seqlens_k): + # Q doc IDs + n_docs = max(1, math.ceil(math.sqrt(max(sq // 4, 1)))) + n_docs = min(n_docs, sq) + if n_docs > 1 and sq > 1: + cuts = sorted(random.sample(range(1, sq), min(n_docs - 1, sq - 1))) + else: + cuts = [] + lengths = [b - a for a, b in zip((0, *cuts), (*cuts, sq))] + doc_ids_q[q_off : q_off + sq] = torch.repeat_interleave( + torch.arange(len(lengths), dtype=torch.int32, device=device), + torch.tensor(lengths, dtype=torch.int32, device=device), + ) + + # K doc IDs (same n_docs range for potential overlap) + if n_docs > 1 and sk > 1: + cuts_k = sorted(random.sample(range(1, sk), min(n_docs - 1, sk - 1))) + else: + cuts_k = [] + lengths_k = [b - a for a, b in zip((0, *cuts_k), (*cuts_k, sk))] + doc_ids_k[k_off : k_off + sk] = torch.repeat_interleave( + torch.arange(len(lengths_k), dtype=torch.int32, device=device), + torch.tensor(lengths_k, dtype=torch.int32, device=device), + ) + + q_off += sq + k_off += sk + + return doc_ids_q, doc_ids_k + + +def make_global_thresholds(seqlens_k, device="cuda"): + """Generate per-global-kv-token thresholds for cute_global_ima_mask. + + For each K token at local index i in a sequence of length sk, + threshold = random value in [0, sk//2]. + Returns thresholds of shape (total_k,). + """ + total_k = sum(seqlens_k) + thresholds = torch.zeros(total_k, dtype=torch.int32, device=device) + k_off = 0 + for sk in seqlens_k: + for i in range(sk): + thresholds[k_off + i] = random.randint(0, max(0, sk // 2)) + k_off += sk + return thresholds + + +def make_global_windows(seqlens_q, device="cuda"): + """Generate per-global-q-token window sizes for cute_global_causal_window_mask. + + For Q token at local index i, window = random value in [0, i] (causal). + Returns windows of shape (total_q,). + """ + total_q = sum(seqlens_q) + windows = torch.zeros(total_q, dtype=torch.int32, device=device) + q_off = 0 + for sq in seqlens_q: + for i in range(sq): + windows[q_off + i] = random.randint(0, i) + q_off += sq + return windows + + # ============================================================================= # Mask registry and factory functions # ============================================================================= diff --git a/tests/cute/test_block_sparsity.py b/tests/cute/test_block_sparsity.py index 18d578d080f..d4624c8f014 100644 --- a/tests/cute/test_block_sparsity.py +++ b/tests/cute/test_block_sparsity.py @@ -24,7 +24,7 @@ def _call_compute_block_sparsity( cute_mask, _ = get_mask_pair( mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size ) - _, torch_tensors = compute_block_sparsity( + torch_tensors = compute_block_sparsity( tile_m=tile_m, tile_n=tile_n, batch_size=batch_size, @@ -481,5 +481,382 @@ def test_fast_sampling(seqlen_q, seqlen_k, tile_m, tile_n, nheads, mask_name): assert all_match, f"Mismatch: {error_msg}" +def _compare_block_sparsity_varlen( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + cu_total_m_blocks, + cu_block_idx_offsets, + seqlens_q, + seqlens_k, + nheads, + tile_m, + tile_n, + mask_name, + window_size=None, +): + """Compare varlen block sparsity against per-sequence fixed-length references.""" + batch_size = len(seqlens_q) + cu_m = cu_total_m_blocks.cpu().tolist() + cu_n = cu_block_idx_offsets.cpu().tolist() + + for b in range(batch_size): + sq, sk = seqlens_q[b], seqlens_k[b] + num_m = (sq + tile_m - 1) // tile_m + num_n = (sk + tile_n - 1) // tile_n + m_off = cu_m[b] + n_off = cu_n[b] + + _, mask_mod_flex = get_mask_pair( + mask_name, seqlen_q=sq, seqlen_k=sk, window_size=window_size + ) + block_mask = create_block_mask( + mask_mod_flex, + B=1, + H=nheads, + Q_LEN=sq, + KV_LEN=sk, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + _, _, ref_mask_cnt, ref_mask_idx, ref_full_cnt, ref_full_idx, *_ = ( + block_mask.as_tuple() + ) + + for h in range(nheads): + for m in range(num_m): + global_m = m_off + m + n_base = n_off + m * num_n + + vl_mask_cnt = mask_block_cnt[h, global_m].item() + vl_full_cnt = full_block_cnt[h, global_m].item() + vl_mask_set = set( + mask_block_idx[h, n_base : n_base + vl_mask_cnt].tolist() + ) + vl_full_set = set( + full_block_idx[h, n_base : n_base + vl_full_cnt].tolist() + ) + + r_mask_cnt = ref_mask_cnt[0, h, m].item() + r_full_cnt = ref_full_cnt[0, h, m].item() + r_mask_set = set(ref_mask_idx[0, h, m, :r_mask_cnt].tolist()) + r_full_set = set(ref_full_idx[0, h, m, :r_full_cnt].tolist()) + + last_m_block = (sq - 1) // tile_m + last_n_block = (sk - 1) // tile_n + m_is_boundary = sq % tile_m != 0 and m == last_m_block + n_is_boundary = sk % tile_n != 0 + + def is_boundary_affected( + n_block, + _m_bnd=m_is_boundary, + _n_bnd=n_is_boundary, + _ln=last_n_block, + ): + return _m_bnd or (_n_bnd and n_block == _ln) + + non_boundary_vl_full = { + n for n in vl_full_set if not is_boundary_affected(n) + } + non_boundary_ref_full = { + n for n in r_full_set if not is_boundary_affected(n) + } + if non_boundary_vl_full != non_boundary_ref_full: + return False, ( + f"Varlen full block mismatch at batch={b}, head={h}, m_block={m} " + f"(sq={sq}, sk={sk}): " + f"varlen={sorted(non_boundary_vl_full)}, ref={sorted(non_boundary_ref_full)}" + ) + + non_boundary_vl_mask = { + n for n in vl_mask_set if not is_boundary_affected(n) + } + non_boundary_ref_mask = { + n for n in r_mask_set if not is_boundary_affected(n) + } + if non_boundary_vl_mask != non_boundary_ref_mask: + return False, ( + f"Varlen partial block mismatch at batch={b}, head={h}, m_block={m} " + f"(sq={sq}, sk={sk}): " + f"varlen={sorted(non_boundary_vl_mask)}, ref={sorted(non_boundary_ref_mask)}" + ) + + return True, "" + + +# ---- Varlen test configurations ---- + +VARLEN_SEQLEN_CONFIGS = [ + # (seqlens_q, seqlens_k) - lists of per-batch lengths + # Uniform lengths (should match fixed-length behavior) + ([128, 128], [128, 128]), + ([256, 256], [256, 256]), + # Different lengths per batch + ([64, 128], [64, 128]), + ([128, 256], [128, 256]), + ([256, 512], [256, 512]), + ([64, 128, 256], [64, 128, 256]), + # Unaligned + ([113, 203], [113, 203]), + ([127, 255], [127, 255]), + ([100, 200, 300], [100, 200, 300]), + # Asymmetric Q/K + ([128, 256], [256, 128]), + ([64, 128], [128, 256]), + # Single element sequences + ([1, 128], [1, 128]), + ([64, 1], [64, 1]), + # Large spread + ([32, 512, 128], [32, 512, 128]), + ([1024, 64], [1024, 64]), +] + + +def _generate_varlen_inputs( + seqlens_q, + seqlens_k, + tile_m, + tile_n, + device="cuda", +): + """Generate cu_seqlens and cu_total_*_blocks for a varlen batch. + + Args: + seqlens_q: list of per-batch query sequence lengths + seqlens_k: list of per-batch key sequence lengths + tile_m, tile_n: tile sizes + Returns: + cu_seqlens_q, cu_seqlens_k, cu_total_m_blocks, cu_block_idx_offsets + """ + batch_size = len(seqlens_q) + assert len(seqlens_k) == batch_size + + cu_seqlens_q = [0] + cu_seqlens_k = [0] + cu_total_m_blocks = [0] + cu_block_idx_offsets = [0] + + for b in range(batch_size): + cu_seqlens_q.append(cu_seqlens_q[-1] + seqlens_q[b]) + cu_seqlens_k.append(cu_seqlens_k[-1] + seqlens_k[b]) + num_m = (seqlens_q[b] + tile_m - 1) // tile_m + num_n = (seqlens_k[b] + tile_n - 1) // tile_n + cu_total_m_blocks.append(cu_total_m_blocks[-1] + num_m) + cu_block_idx_offsets.append(cu_block_idx_offsets[-1] + num_m * num_n) + + return ( + torch.tensor(cu_seqlens_q, device=device, dtype=torch.int32), + torch.tensor(cu_seqlens_k, device=device, dtype=torch.int32), + torch.tensor(cu_total_m_blocks, device=device, dtype=torch.int32), + torch.tensor(cu_block_idx_offsets, device=device, dtype=torch.int32), + ) + + +def _call_compute_block_sparsity_varlen( + seqlens_q, + seqlens_k, + nheads, + tile_m, + tile_n, + mask_name, + window_size=None, + aux_tensors=None, + use_fast_sampling=False, +): + """Call compute_block_sparsity with varlen inputs.""" + batch_size = len(seqlens_q) + # Use max seqlens for mask_mod compilation (the kernel uses per-batch seqlens at runtime) + max_seqlen_q = max(seqlens_q) + max_seqlen_k = max(seqlens_k) + + cute_mask, _ = get_mask_pair( + mask_name, seqlen_q=max_seqlen_q, seqlen_k=max_seqlen_k, window_size=window_size + ) + + cu_seqlens_q, cu_seqlens_k, cu_total_m_blocks, cu_block_idx_offsets = ( + _generate_varlen_inputs(seqlens_q, seqlens_k, tile_m, tile_n) + ) + + torch_tensors = compute_block_sparsity( + tile_m=tile_m, + tile_n=tile_n, + batch_size=batch_size, + num_heads=nheads, + seqlen_q=max_seqlen_q, + seqlen_k=max_seqlen_k, + mask_mod=cute_mask, + aux_tensors=aux_tensors, + device="cuda", + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + cu_total_m_blocks=cu_total_m_blocks, + cu_block_idx_offsets=cu_block_idx_offsets, + use_fast_sampling=use_fast_sampling, + ) + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = torch_tensors + return ( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + cu_total_m_blocks, + cu_block_idx_offsets, + ) + + +@pytest.mark.parametrize("seqlens_q,seqlens_k", VARLEN_SEQLEN_CONFIGS) +@pytest.mark.parametrize("tile_m,tile_n", [(64, 64), (128, 128), (64, 128), (128, 64)]) +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize("mask_name", ["causal", "block_diagonal"]) +def test_varlen(seqlens_q, seqlens_k, tile_m, tile_n, nheads, mask_name): + """Test variable-length sequence support.""" + ( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + cu_total_m_blocks, + cu_block_idx_offsets, + ) = _call_compute_block_sparsity_varlen( + seqlens_q, seqlens_k, nheads, tile_m, tile_n, mask_name + ) + + all_match, error_msg = _compare_block_sparsity_varlen( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + cu_total_m_blocks, + cu_block_idx_offsets, + seqlens_q, + seqlens_k, + nheads, + tile_m, + tile_n, + mask_name, + ) + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize( + "seqlens_q,seqlens_k", + [ + ([128, 128], [128, 128]), + ([64, 128, 256], [64, 128, 256]), + ([100, 200], [100, 200]), + ], +) +@pytest.mark.parametrize("tile_m,tile_n", [(64, 64), (128, 128)]) +@pytest.mark.parametrize("nheads", [1]) +@pytest.mark.parametrize( + "mask_name,window_size", + [("causal", None), ("sliding_window", 64), ("sliding_window", 256)], +) +def test_varlen_parameterized_masks( + seqlens_q, seqlens_k, tile_m, tile_n, nheads, mask_name, window_size +): + """Test varlen with parameterized masks.""" + # Skip sliding window when any seqlen_q > seqlen_k + if mask_name == "sliding_window" and any( + sq > sk for sq, sk in zip(seqlens_q, seqlens_k) + ): + pytest.skip("Sliding window not supported for seqlen_q > seqlen_k") + + ( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + cu_total_m_blocks, + cu_block_idx_offsets, + ) = _call_compute_block_sparsity_varlen( + seqlens_q, seqlens_k, nheads, tile_m, tile_n, mask_name, window_size=window_size + ) + + all_match, error_msg = _compare_block_sparsity_varlen( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + cu_total_m_blocks, + cu_block_idx_offsets, + seqlens_q, + seqlens_k, + nheads, + tile_m, + tile_n, + mask_name, + window_size=window_size, + ) + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize("tile_m,tile_n", [(64, 64), (128, 128)]) +def test_varlen_matches_fixed_length(nheads, tile_m, tile_n): + """Verify that varlen with uniform sequence lengths produces identical + results to the fixed-length path.""" + seqlen_q, seqlen_k = 256, 256 + batch_size = 3 + mask_name = "causal" + + # Fixed-length result + fixed_mask_cnt, fixed_mask_idx, fixed_full_cnt, fixed_full_idx = ( + _call_compute_block_sparsity( + batch_size, nheads, seqlen_q, seqlen_k, tile_m, tile_n, mask_name + ) + ) + + # Varlen with uniform lengths + seqlens_q = [seqlen_q] * batch_size + seqlens_k = [seqlen_k] * batch_size + ( + vl_mask_cnt, + vl_mask_idx, + vl_full_cnt, + vl_full_idx, + cu_total_m_blocks, + cu_block_idx_offsets, + ) = _call_compute_block_sparsity_varlen( + seqlens_q, seqlens_k, nheads, tile_m, tile_n, mask_name + ) + + num_m = (seqlen_q + tile_m - 1) // tile_m + num_n = (seqlen_k + tile_n - 1) // tile_n + cu_m = cu_total_m_blocks.cpu().tolist() + cu_n = cu_block_idx_offsets.cpu().tolist() + + for b in range(batch_size): + for h in range(nheads): + for m in range(num_m): + global_m = cu_m[b] + m + n_base = cu_n[b] + m * num_n + + # Counts should match + assert ( + vl_mask_cnt[h, global_m].item() == fixed_mask_cnt[b, h, m].item() + ), f"Mask count mismatch at b={b}, h={h}, m={m}" + assert ( + vl_full_cnt[h, global_m].item() == fixed_full_cnt[b, h, m].item() + ), f"Full count mismatch at b={b}, h={h}, m={m}" + + mc = vl_mask_cnt[h, global_m].item() + fc = vl_full_cnt[h, global_m].item() + vl_mask_set = set(vl_mask_idx[h, n_base : n_base + mc].tolist()) + vl_full_set = set(vl_full_idx[h, n_base : n_base + fc].tolist()) + fixed_mask_set = set(fixed_mask_idx[b, h, m, :mc].tolist()) + fixed_full_set = set(fixed_full_idx[b, h, m, :fc].tolist()) + + assert vl_mask_set == fixed_mask_set, ( + f"Mask idx mismatch at b={b}, h={h}, m={m}: " + f"varlen={sorted(vl_mask_set)}, fixed={sorted(fixed_mask_set)}" + ) + assert vl_full_set == fixed_full_set, ( + f"Full idx mismatch at b={b}, h={h}, m={m}: " + f"varlen={sorted(vl_full_set)}, fixed={sorted(fixed_full_set)}" + ) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index ceef6500b97..9e4c440fe0c 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -295,10 +295,11 @@ def _run_mask_test( doc_ids = random_doc_id_tensor(nheads, batch_size, doc_len, device="cuda").to( dtype=torch.int32, device="cuda" ) - original_flex_mask = mask_mod_flex + doc_ids.__leading_dim__ = 2 + doc_row = doc_ids[0, 0] # (doc_len,); batch_size=1, all heads identical - def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): - return original_flex_mask(b, h, q_idx, kv_idx, doc_ids) + def mask_mod_flex(b, h, q_idx, kv_idx): + return doc_row[q_idx] == doc_row[kv_idx] aux_tensors_arg = [doc_ids] elif mask_name == "ima": diff --git a/tests/cute/test_mask_mod_varlen.py b/tests/cute/test_mask_mod_varlen.py new file mode 100644 index 00000000000..e935d4d4430 --- /dev/null +++ b/tests/cute/test_mask_mod_varlen.py @@ -0,0 +1,895 @@ +# mask_mod varlen test script +# Forward-only +# +# Since flex_attention doesn't support varlen natively, we compare +# results sequence-by-sequence: run the kernel with cu_seqlens (packed), +# then run flex_attention per-sequence and compare. +# +# Usage: +# pytest test_mask_mod_varlen.py -v -s + +import math +import random + +import pytest +import torch +import torch.nn.functional as F +import cutlass +import cutlass.cute as cute +from torch.nn.attention.flex_attention import create_block_mask, flex_attention + +from flash_attn.cute.interface import _flash_attn_fwd +from flash_attn.cute import utils +from flash_attn.cute.compute_block_sparsity import compute_block_sparsity +from mask_mod_definitions import ( + get_mask_pair, + random_doc_id_tensor, + STATIC_MASKS, + PARAMETERIZED_MASK_FACTORIES, + cute_global_packed_doc_mask, + cute_global_ima_mask, + cute_global_causal_window_mask, + global_packed_doc_flex_factory, + global_ima_flex_factory, + global_causal_window_flex_factory, + make_packed_doc_ids, + make_global_thresholds, + make_global_windows, +) + +COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] + + +@pytest.fixture(autouse=True) +def reset_torch_state(): + """Reset torch dynamo/compile state between tests to avoid state pollution.""" + torch._dynamo.reset() + torch.cuda.empty_cache() + yield + torch._dynamo.reset() + torch.cuda.empty_cache() + + +# ============================================================================= +# Seqlen configs for varlen (list of per-sequence lengths) +# ============================================================================= + +SEQLEN_CONFIGS = [ + # Simple cases + ([1], [1]), + ([64], [64]), + ([128], [128]), + # Multiple sequences, same length + ([128, 128], [128, 128]), + ([64, 64, 64], [64, 64, 64]), + # Multiple sequences, varying lengths + ([64, 128], [64, 128]), + ([32, 64, 128], [32, 64, 128]), + ([113, 203], [113, 203]), + ([256, 512], [256, 512]), + # Asymmetric Q/K lengths + ([64, 128], [32, 64]), + ([100, 100], [50, 50]), + # Edge cases + ([1, 1], [1, 1]), + ([1, 256], [1, 256]), + ([256, 1], [256, 1]), + ([17, 33, 65], [17, 33, 65]), + # Larger sequences + ([1024, 1024], [1024, 1024]), + ([256, 512, 256], [128, 256, 128]), +] + +SEQLEN_CONFIGS_SMOKE = [ + ([128, 128], [128, 128]), + ([64, 128], [64, 128]), + ([113, 203], [113, 203]), + ([256, 512], [256, 512]), + ([64, 128], [32, 64]), +] + + +# ============================================================================= +# Helper functions +# ============================================================================= + + +def setup_varlen_tensors( + seqlens_q, seqlens_k, num_heads, num_kv_heads, head_dim, dtype +): + """Create packed Q, K, V tensors and cu_seqlens for varlen.""" + device = "cuda" + batch_size = len(seqlens_q) + total_q = sum(seqlens_q) + total_k = sum(seqlens_k) + + q = torch.randn(total_q, num_heads, head_dim, device=device, dtype=dtype) + k = torch.randn(total_k, num_kv_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(total_k, num_kv_heads, head_dim, device=device, dtype=dtype) + + cu_seqlens_q = torch.tensor( + [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()), + device=device, + dtype=torch.int32, + ) + cu_seqlens_k = torch.tensor( + [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()), + device=device, + dtype=torch.int32, + ) + + return q, k, v, cu_seqlens_q, cu_seqlens_k + + +def run_flex_per_sequence( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + mask_mod_flex_factory, + seqlens_q, + seqlens_k, + num_heads, + num_kv_heads, + head_dim, + dtype=None, +): + """Run flex_attention per-sequence as reference for varlen. + + mask_mod_flex_factory(seq_idx, seqlen_q_i, seqlen_k_i) -> mask_mod function + that takes (b, h, q_idx, kv_idx) for that sequence. + """ + batch_size = len(seqlens_q) + results = [] + + for i in range(batch_size): + sq = seqlens_q[i] + sk = seqlens_k[i] + + # Extract packed slices + q_slice = q[cu_seqlens_q[i] : cu_seqlens_q[i + 1]].unsqueeze(0) # (1, sq, H, D) + k_slice = k[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].unsqueeze( + 0 + ) # (1, sk, Hkv, D) + v_slice = v[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].unsqueeze(0) + + if dtype is not None: + q_slice = q_slice.to(dtype) + k_slice = k_slice.to(dtype) + v_slice = v_slice.to(dtype) + + # Transpose to (B, H, S, D) for flex_attention + q_t = q_slice.transpose(1, 2) + k_t = k_slice.transpose(1, 2) + v_t = v_slice.transpose(1, 2) + + # Expand KV heads for GQA + if num_heads != num_kv_heads: + repeat_factor = num_heads // num_kv_heads + k_t = k_t.repeat_interleave(repeat_factor, dim=1) + v_t = v_t.repeat_interleave(repeat_factor, dim=1) + + scale = 1.0 / math.sqrt(head_dim) + + mask_mod = mask_mod_flex_factory(i, sq, sk) + + if mask_mod is None: + out = F.scaled_dot_product_attention(q_t, k_t, v_t, scale=scale) + else: + block_mask = create_block_mask( + mask_mod, + B=1, + H=num_heads, + Q_LEN=sq, + KV_LEN=sk, + device=q.device, + ) + out = flex_attention( + q_t, k_t, v_t, block_mask=block_mask, scale=scale, enable_gqa=True + ) + + results.append(out.transpose(1, 2).squeeze(0)) # back to (sq, H, D) + + return torch.cat(results, dim=0) + + +def check_varlen_results( + out_cute, + out_ref_fp32, + out_pt, + seqlens_q, + cu_seqlens_q, + test_name, + rtol=2, + extra_atol=2e-3, +): + """Compare CuTE output against per-sequence flex references.""" + assert not torch.isnan(out_cute).any(), f"{test_name}: NaN in output" + assert torch.isfinite(out_cute).all(), f"{test_name}: Inf in output" + assert out_cute.shape == out_ref_fp32.shape, ( + f"{test_name}: Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}" + ) + + num_seqs = len(seqlens_q) + max_cute_error = 0.0 + max_pt_error = 0.0 + + for i in range(num_seqs): + start = cu_seqlens_q[i] + end = cu_seqlens_q[i + 1] + cute_seq = out_cute[start:end] + ref_seq = out_ref_fp32[start:end] + pt_seq = out_pt[start:end] + + max_cute_error = max(max_cute_error, (cute_seq - ref_seq).abs().max().item()) + max_pt_error = max(max_pt_error, (pt_seq - ref_seq).abs().max().item()) + + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + + print(f"\n{test_name}:") + print(f" PyTorch vs FP32 ref: {max_pt_error:.2e}") + print(f" CuTE vs FP32 ref: {max_cute_error:.2e}") + + tol = rtol * max_pt_error + fwd_atol + extra_atol + assert max_cute_error <= tol, ( + f"{test_name}: CuTE error {max_cute_error:.2e} exceeds tolerance {tol:.2e} " + f"(rtol={rtol} * pt_err={max_pt_error:.2e} + fwd_atol={fwd_atol:.2e} + extra={extra_atol:.2e})" + ) + + +# ============================================================================= +# Core test runner +# ============================================================================= + + +def _run_varlen_mask_test( + seqlens_q, + seqlens_k, + num_heads, + num_kv_heads, + head_dim, + dtype, + mask_name, + window_size=None, +): + """Run a varlen mask_mod test: kernel with cu_seqlens vs per-sequence flex_attention.""" + torch.manual_seed(42) + random.seed(42) + + batch_size = len(seqlens_q) + pack_gqa = num_heads != num_kv_heads + + if mask_name == "sliding_window": + # Skip configs where any seqlen_q > seqlen_k + for sq, sk in zip(seqlens_q, seqlens_k): + if sq > sk: + pytest.skip( + "sliding_window requires seqlen_q <= seqlen_k for each sequence" + ) + + q, k, v, cu_seqlens_q, cu_seqlens_k = setup_varlen_tensors( + seqlens_q, seqlens_k, num_heads, num_kv_heads, head_dim, dtype + ) + + if mask_name == "block_causal": + offsets = [sk - sq for sq, sk in zip(seqlens_q, seqlens_k)] + if len(set(offsets)) > 1: + pytest.skip( + "block_causal captures offset as compile-time constant; " + "varlen with different per-sequence offsets not supported" + ) + + aux_tensors_arg = None + + if mask_name == "document": + max_seqlen = max(max(seqlens_q), max(seqlens_k)) + max_doc_len = max(max(seqlens_q), max(seqlens_k)) + doc_ids = random_doc_id_tensor( + num_heads, batch_size, max_doc_len, device="cuda" + ).to(dtype=torch.int32, device="cuda") + doc_ids.__leading_dim__ = 2 + aux_tensors_arg = [doc_ids] + + from mask_mod_definitions import flex_document_mask + + cute_mask_mod = get_mask_pair("document")[0] + + def flex_factory(seq_idx, sq, sk, doc_ids=doc_ids): + # Pre-slice to 1D using Python ints *outside* the vmapped closure. + # create_block_mask vmaps over all four args (b, h, q_idx, kv_idx); + # multi-dim indexing like doc_id[b, h, q_idx] with 0-dim vmap tensors + # triggers .item() internally. 1D tensor[0d_tensor] is a safe gather. + doc_row = doc_ids[seq_idx, 0] # (max_doc_len,) + + def _mask(b, h, q_idx, kv_idx): + return doc_row[q_idx] == doc_row[kv_idx] + + return _mask + + elif mask_name == "ima": + total_k = sum(seqlens_k) + pytest.skip( + "IMA mask requires global index handling for varlen - not yet implemented" + ) + + else: + if mask_name in STATIC_MASKS: + cute_mask_mod = get_mask_pair(mask_name)[0] + + def flex_factory(seq_idx, sq, sk): + return get_mask_pair(mask_name)[1] + + elif mask_name in PARAMETERIZED_MASK_FACTORIES: + cute_mask_mod = get_mask_pair( + mask_name, + seqlen_q=seqlens_q[0], + seqlen_k=seqlens_k[0], + window_size=window_size, + )[0] + + def flex_factory(seq_idx, sq, sk): + _, flex_mask = get_mask_pair( + mask_name, + seqlen_q=sq, + seqlen_k=sk, + window_size=window_size, + ) + return flex_mask + + else: + raise ValueError(f"Unknown mask: {mask_name}") + + # Run the kernel with varlen (packed format) + out = torch.empty_like(q) + softmax_scale = 1.0 / math.sqrt(head_dim) + + out_tuple = _flash_attn_fwd( + q=q, + k=k, + v=v, + out=out, + lse=None, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=softmax_scale, + causal=False, + softcap=None, + window_size_left=-1, + window_size_right=-1, + learnable_sink=None, + tile_mn=(128, 128), + pack_gqa=pack_gqa, + _arch=None, + score_mod=None, + mask_mod=cute_mask_mod, + block_sparse_tensors=None, + return_lse=True, + aux_tensors=aux_tensors_arg, + ) + out_cute = out_tuple[0] + + # Run per-sequence flex_attention references + out_ref_fp32 = run_flex_per_sequence( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + flex_factory, + seqlens_q, + seqlens_k, + num_heads, + num_kv_heads, + head_dim, + dtype=torch.float32, + ) + out_pt = run_flex_per_sequence( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + flex_factory, + seqlens_q, + seqlens_k, + num_heads, + num_kv_heads, + head_dim, + dtype=dtype, + ) + + # Check results + mask_desc = f"mask_mod={mask_name}" + if window_size is not None: + mask_desc += f"(w={window_size})" + test_name = ( + f"{mask_desc} varlen seqs_q={seqlens_q}, seqs_k={seqlens_k}, " + f"H={num_heads}/{num_kv_heads}, D={head_dim}" + ) + check_varlen_results( + out_cute, out_ref_fp32, out_pt, seqlens_q, cu_seqlens_q, test_name + ) + + +# ============================================================================= +# Test cases +# ============================================================================= + +# Masks that don't need recompilation per seqlen (fast) +STATIC_MASK_NAMES = ["block_diagonal", "mini_causal"] + +# Masks that need per-seqlen compilation (slower) +PARAMETERIZED_MASK_CONFIGS = [ + ("causal", None), + ("block_causal", None), + ("sliding_window", 128), + ("sliding_window", 256), + ("document", None), +] + + +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("kv_mode", ["mha", "gqa"]) +@pytest.mark.parametrize("mask_name", STATIC_MASK_NAMES) +def test_varlen_static_masks(seqlens_q, seqlens_k, dtype, kv_mode, mask_name): + """Test static mask_mods with varlen (packed) attention.""" + num_heads = 8 + if kv_mode == "gqa": + if COMPUTE_CAPABILITY < 9: + pytest.xfail("pack_gqa requires SM90+") + num_kv_heads = 2 + else: + num_kv_heads = num_heads + + _run_varlen_mask_test( + seqlens_q=seqlens_q, + seqlens_k=seqlens_k, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=128, + dtype=dtype, + mask_name=mask_name, + ) + + +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("kv_mode", ["mha", "gqa"]) +@pytest.mark.parametrize("mask_name,window_size", PARAMETERIZED_MASK_CONFIGS) +def test_varlen_parameterized_masks( + seqlens_q, seqlens_k, dtype, kv_mode, mask_name, window_size +): + """Test parameterized mask_mods with varlen (packed) attention. + + Uses fewer seqlen configs since these require recompilation per seqlen. + """ + num_heads = 8 + if kv_mode == "gqa": + if COMPUTE_CAPABILITY < 9: + pytest.xfail("pack_gqa requires SM90+") + num_kv_heads = 2 + else: + num_kv_heads = num_heads + + _run_varlen_mask_test( + seqlens_q=seqlens_q, + seqlens_k=seqlens_k, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=128, + dtype=dtype, + mask_name=mask_name, + window_size=window_size, + ) + + +# ============================================================================= +# Global-index mask test runner +# ============================================================================= + + +def _run_varlen_global_mask_test( + seqlens_q, + seqlens_k, + num_heads, + num_kv_heads, + head_dim, + dtype, + mask_name, +): + """Run a varlen global-index mask_mod test: kernel vs per-sequence flex_attention.""" + torch.manual_seed(42) + random.seed(42) + + pack_gqa = num_heads != num_kv_heads + + q, k, v, cu_seqlens_q, cu_seqlens_k = setup_varlen_tensors( + seqlens_q, seqlens_k, num_heads, num_kv_heads, head_dim, dtype + ) + + if mask_name == "global_packed_doc": + doc_ids_q, doc_ids_k = make_packed_doc_ids(seqlens_q, seqlens_k, device="cuda") + cute_mask = cute_global_packed_doc_mask + flex_fac = global_packed_doc_flex_factory( + doc_ids_q, doc_ids_k, cu_seqlens_q, cu_seqlens_k + ) + aux_tensors = [doc_ids_q, doc_ids_k] + elif mask_name == "global_ima": + thresholds = make_global_thresholds(seqlens_k, device="cuda") + cute_mask = cute_global_ima_mask + flex_fac = global_ima_flex_factory(thresholds, cu_seqlens_k) + aux_tensors = [thresholds] + elif mask_name == "global_causal_window": + windows = make_global_windows(seqlens_q, device="cuda") + cute_mask = cute_global_causal_window_mask + flex_fac = global_causal_window_flex_factory(windows, cu_seqlens_q) + aux_tensors = [windows] + else: + raise ValueError(f"Unknown global mask: {mask_name}") + + out = torch.empty_like(q) + softmax_scale = 1.0 / math.sqrt(head_dim) + + out_tuple = _flash_attn_fwd( + q=q, + k=k, + v=v, + out=out, + lse=None, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=softmax_scale, + causal=False, + softcap=None, + window_size_left=-1, + window_size_right=-1, + learnable_sink=None, + tile_mn=(128, 128), + pack_gqa=pack_gqa, + _arch=None, + score_mod=None, + mask_mod=cute_mask, + block_sparse_tensors=None, + return_lse=True, + aux_tensors=aux_tensors, + ) + out_cute = out_tuple[0] + + out_ref_fp32 = run_flex_per_sequence( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + flex_fac, + seqlens_q, + seqlens_k, + num_heads, + num_kv_heads, + head_dim, + dtype=torch.float32, + ) + out_pt = run_flex_per_sequence( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + flex_fac, + seqlens_q, + seqlens_k, + num_heads, + num_kv_heads, + head_dim, + dtype=dtype, + ) + + test_name = ( + f"global_mask={mask_name} varlen seqs_q={seqlens_q}, seqs_k={seqlens_k}, " + f"H={num_heads}/{num_kv_heads}, D={head_dim}" + ) + check_varlen_results( + out_cute, out_ref_fp32, out_pt, seqlens_q, cu_seqlens_q, test_name + ) + + +GLOBAL_MASK_NAMES = ["global_packed_doc", "global_ima", "global_causal_window"] + + +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("mask_name", GLOBAL_MASK_NAMES) +def test_varlen_global_masks(seqlens_q, seqlens_k, mask_name): + """Test global-index mask_mods (aux-tensor-driven) with varlen packed attention.""" + _run_varlen_global_mask_test( + seqlens_q, seqlens_k, 8, 8, 128, torch.bfloat16, mask_name + ) + + +# ============================================================================= +# Block sparsity end-to-end tests +# ============================================================================= + + +def _make_block_sparse_tensors( + mask_mod, + seqlens_q, + seqlens_k, + num_heads, + tile_m, + tile_n, + device, + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_k=None, + aux_tensors=None, +): + """Compute block sparse tensors. cu_total_m_blocks / cu_block_idx_offsets are + populated on the returned BlockSparseTensorsTorch.""" + batch_size = len(seqlens_q) + max_seqlen_q = max(seqlens_q) + max_seqlen_k = max(seqlens_k) + + if cu_seqlens_q is not None: + cu_total_m_blocks_list = [0] + for batch_idx in range(batch_size): + num_m_blocks = (seqlens_q[batch_idx] + tile_m - 1) // tile_m + cu_total_m_blocks_list.append(cu_total_m_blocks_list[-1] + num_m_blocks) + cu_total_m_blocks = torch.tensor( + cu_total_m_blocks_list, dtype=torch.int32, device=device + ) + + cu_block_idx_offsets = None + if cu_seqlens_k is not None or seqused_k is not None: + cu_block_idx_offsets_list = [0] + for batch_idx in range(batch_size): + num_m_blocks = (seqlens_q[batch_idx] + tile_m - 1) // tile_m + num_n_blocks = (seqlens_k[batch_idx] + tile_n - 1) // tile_n + cu_block_idx_offsets_list.append( + cu_block_idx_offsets_list[-1] + num_m_blocks * num_n_blocks + ) + cu_block_idx_offsets = torch.tensor( + cu_block_idx_offsets_list, dtype=torch.int32, device=device + ) + else: + cu_total_m_blocks = None + cu_block_idx_offsets = None + + block_sparse_tensors = compute_block_sparsity( + tile_m=tile_m, + tile_n=tile_n, + batch_size=batch_size, + num_heads=num_heads, + seqlen_q=max_seqlen_q, + seqlen_k=max_seqlen_k, + mask_mod=mask_mod, + aux_tensors=aux_tensors, + device=device, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + cu_total_m_blocks=cu_total_m_blocks, + cu_block_idx_offsets=cu_block_idx_offsets, + seqused_k=seqused_k, + ) + return block_sparse_tensors + + +def _run_fwd( + q, + k, + v, + mask_mod, + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_k=None, + block_sparse_tensors=None, + aux_tensors=None, +): + out = torch.empty_like(q) + return _flash_attn_fwd( + q=q, + k=k, + v=v, + out=out, + lse=None, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=None, + seqused_k=seqused_k, + page_table=None, + softmax_scale=1.0 / math.sqrt(q.shape[-1]), + causal=False, + softcap=None, + window_size_left=-1, + window_size_right=-1, + learnable_sink=None, + tile_mn=(128, 128), + pack_gqa=False, + _arch=None, + score_mod=None, + mask_mod=mask_mod, + block_sparse_tensors=block_sparse_tensors, + return_lse=False, + aux_tensors=aux_tensors, + )[0] + + +BLOCK_SPARSE_MASK_NAMES = [ + "causal", + "block_diagonal", + "mini_causal", + "prefix_lm", + "sliding_window", + "dilated_sliding_window", + "document", + "ima", +] + +BLOCK_SPARSE_SEQLEN_CONFIGS = [ + ([128, 192, 256], [128, 192, 256]), + ([64, 128], [64, 128]), + ([256, 512, 256], [256, 512, 256]), + ([128, 192, 256], [64, 128, 192]), +] + + +# @pytest.mark.parametrize("varlen_q", [False, True]) +@pytest.mark.parametrize("seqlens", BLOCK_SPARSE_SEQLEN_CONFIGS) +@pytest.mark.parametrize("mask_name", BLOCK_SPARSE_MASK_NAMES) +@pytest.mark.parametrize("varlen_q", [False, True]) +@pytest.mark.parametrize("varlen_k", [False, True]) +@pytest.mark.parametrize("use_seqused_k", [False, True]) +@pytest.mark.parametrize("head_broadcast", [False, True]) +def test_varlen_block_sparse( + varlen_q, varlen_k, use_seqused_k, head_broadcast, mask_name, seqlens +): + """Block sparsity + mask_mod should produce identical output to mask_mod alone.""" + if varlen_k and use_seqused_k: + pytest.skip("packed K (cu_seqlens_k) and seqused_k are mutually exclusive") + if not varlen_q and varlen_k: + pytest.skip( + "block sparsity with padded Q + packed K requires per-batch n-block offsets; not yet supported" + ) + + torch.manual_seed(42) + random.seed(42) + device = "cuda" + num_heads = 4 + head_dim = 128 + dtype = torch.bfloat16 + # On Blackwell (SM100) q_stage=2 → effective tile is 256; elsewhere 128. + tile_m = 256 if COMPUTE_CAPABILITY >= 10 else 128 + tile_n = 128 + + base_seqlens_q, base_seqlens_k = seqlens + batch_size = len(base_seqlens_q) + max_seqlen_q = max(base_seqlens_q) + max_seqlen_k = max(base_seqlens_k) + + seqlens_q = base_seqlens_q if varlen_q else [max_seqlen_q] * batch_size + seqlens_k = ( + base_seqlens_k if (varlen_k or use_seqused_k) else [max_seqlen_k] * batch_size + ) + + def make_cu_seqlens(seqlens): + return torch.tensor( + [0] + list(torch.tensor(seqlens).cumsum(0).tolist()), + device=device, + dtype=torch.int32, + ) + + if varlen_q: + q = torch.randn(sum(seqlens_q), num_heads, head_dim, device=device, dtype=dtype) + cu_seqlens_q = make_cu_seqlens(seqlens_q) + else: + q = torch.randn( + batch_size, max_seqlen_q, num_heads, head_dim, device=device, dtype=dtype + ) + cu_seqlens_q = None + + if varlen_k: + k = torch.randn(sum(seqlens_k), num_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(sum(seqlens_k), num_heads, head_dim, device=device, dtype=dtype) + cu_seqlens_k = make_cu_seqlens(seqlens_k) + seqused_k = None + else: + k = torch.randn( + batch_size, max_seqlen_k, num_heads, head_dim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, max_seqlen_k, num_heads, head_dim, device=device, dtype=dtype + ) + cu_seqlens_k = None + seqused_k = ( + torch.tensor(seqlens_k, dtype=torch.int32, device=device) + if use_seqused_k + else None + ) + + # Build mask_mod and aux_tensors for the requested mask + aux_tensors = None + if mask_name == "causal": + mask_mod, _ = get_mask_pair( + "causal", seqlen_q=max_seqlen_q, seqlen_k=max_seqlen_k + ) + elif mask_name == "sliding_window": + mask_mod, _ = get_mask_pair( + "sliding_window", + seqlen_q=max_seqlen_q, + seqlen_k=max_seqlen_k, + window_size=128, + ) + elif mask_name == "document": + max_doc_len = max(max_seqlen_q, max_seqlen_k) + doc_ids = random_doc_id_tensor( + num_heads, batch_size, max_doc_len, device=device + ) + doc_ids.__leading_dim__ = 2 + aux_tensors = [doc_ids] + mask_mod = get_mask_pair("document")[0] + elif mask_name == "ima": + bias = torch.randint( + 0, + max(1, max_seqlen_k // 2), + (max_seqlen_k,), + dtype=torch.int32, + device=device, + ) + aux_tensors = [bias] + mask_mod = get_mask_pair("ima")[0] + else: + mask_mod = get_mask_pair(mask_name)[0] + + num_heads_sparse = 1 if head_broadcast else num_heads + block_sparse_tensors = _make_block_sparse_tensors( + mask_mod=mask_mod, + seqlens_q=seqlens_q, + seqlens_k=seqlens_k, + num_heads=num_heads_sparse, + tile_m=tile_m, + tile_n=tile_n, + device=device, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + aux_tensors=aux_tensors, + ) + + out_with_block_sparsity = _run_fwd( + q, + k, + v, + mask_mod, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + block_sparse_tensors=block_sparse_tensors, + aux_tensors=aux_tensors, + ) + out_no_block_sparsity = _run_fwd( + q, + k, + v, + mask_mod, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + aux_tensors=aux_tensors, + ) + + assert not torch.isnan(out_with_block_sparsity).any(), "NaN in block-sparse output" + max_err = (out_with_block_sparsity - out_no_block_sparsity).abs().max().item() + assert max_err <= 0.01, ( + f"block-sparse output differs from mask-mod-only by {max_err}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) From 09aa322287571356db9fff765e9b10f005a3e9f5 Mon Sep 17 00:00:00 2001 From: wangsiyu Date: Thu, 7 May 2026 17:41:42 +0800 Subject: [PATCH 05/21] [FA4][hd256] Fix layout of non-contiguous qkv in backward kernel (#2545) --- flash_attn/cute/mask.py | 6 ++ ...100_hd256_2cta_fmha_backward_dkdvkernel.py | 61 +++++++++++++++---- ...sm100_hd256_2cta_fmha_backward_dqkernel.py | 57 ++++++++++++++--- 3 files changed, 102 insertions(+), 22 deletions(-) diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 9c171ba9865..daa2e9c2d5c 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -871,6 +871,9 @@ def get_trip_start_count_via_block_info( padded_offset_k=Int32(0), seqlen_q=seqlen_q, seqlen_k=seqlen_k, + m_block_offset=Int32(0), + block_idx_offset=Int32(0), + num_n_blocks=cute.ceil_div(seqlen_k, tile_shape[1]), has_cu_seqlens_q=False, has_cu_seqlens_k=False, has_seqused_q=False, @@ -911,6 +914,9 @@ def get_trip_mask_bounds_via_block_info( padded_offset_k=Int32(0), seqlen_q=seqlen_q, seqlen_k=seqlen_k, + m_block_offset=Int32(0), + block_idx_offset=Int32(0), + num_n_blocks=cute.ceil_div(seqlen_k, tile_shape[1]), has_cu_seqlens_q=False, has_cu_seqlens_k=False, has_seqused_q=False, diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py index 7a8cdeede6a..6c9db87458f 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py @@ -244,12 +244,8 @@ def __call__( cute.assume(Q.stride[1], divby=64), Q.stride[4], ( - (Q.shape[4], Q.shape[4] * Q.shape[3]), - ( - 0 - if varlen - else cute.assume(Q.shape[1] * Q.shape[4] * h_r * h_k, divby=64) - ), + (Q.stride[3], Q.stride[2]), + 0 if cumulative_s_q is not None else cute.assume(Q.stride[0], divby=64), ), ), ), @@ -263,8 +259,8 @@ def __call__( cute.assume(K.stride[1], divby=64), K.stride[4], ( - (0, K.shape[4]), - (0 if varlen else cute.assume(K.shape[1] * K.shape[4] * 1 * h_k, divby=64)), + (0, K.stride[2]), + 0 if cumulative_s_k is not None else cute.assume(K.stride[0], divby=64), ), ), ), @@ -278,8 +274,8 @@ def __call__( cute.assume(V.stride[1], divby=64), V.stride[4], ( - (0, V.shape[4]), - (0 if varlen else cute.assume(V.shape[1] * V.shape[4] * 1 * h_k, divby=64)), + (0, V.stride[2]), + 0 if cumulative_s_k is not None else cute.assume(V.stride[0], divby=64), ), ), ), @@ -296,10 +292,49 @@ def __call__( ), ), ) - dK = cute.make_tensor(dK.iterator, K.layout) - dV = cute.make_tensor(dV.iterator, V.layout) + dK = cute.make_tensor( + dK.iterator, + cute.make_layout( + (dK.shape[1], dK.shape[4], hb), + stride=( + cute.assume(dK.stride[1], divby=64), + dK.stride[4], + ( + (0, dK.stride[2]), + 0 if cumulative_s_k is not None else cute.assume(dK.stride[0], divby=64), + ), + ), + ), + ) + dV = cute.make_tensor( + dV.iterator, + cute.make_layout( + (dV.shape[1], dV.shape[4], hb), + stride=( + cute.assume(dV.stride[1], divby=64), + dV.stride[4], + ( + (0, dV.stride[2]), + 0 if cumulative_s_k is not None else cute.assume(dV.stride[0], divby=64), + ), + ), + ), + ) # (s, d, ((h_r, h_k), b)) - dO = cute.make_tensor(dO.iterator, Q.layout) + dO = cute.make_tensor( + dO.iterator, + cute.make_layout( + (dO.shape[1], dO.shape[4], hb), + stride=( + cute.assume(dO.stride[1], divby=64), + dO.stride[4], + ( + (dO.stride[3], dO.stride[2]), + 0 if cumulative_s_q is not None else cute.assume(dO.stride[0], divby=64), + ), + ), + ), + ) # (s, d, ((h_r, h_k), b)) -> (d, s, ((h_r, h_k), b)) dOT = cute.make_tensor( diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py index b25ca48f007..a95c9677b65 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py @@ -187,7 +187,6 @@ def __call__( s_q64 = Int64(s_q) s_k64 = Int64(s_k) s_lse64 = Int64(s_lse) - d64 = cute.assume(Int64(d), divby=128) h_r64 = Int64(h_r) h_k64 = Int64(h_k) b64 = Int64(b) @@ -196,39 +195,72 @@ def __call__( # `cuseqlen_*` offsets stays within the tensor domain. s_q_total = q_tensor.shape[1] if cum_seqlen_q is not None else s_q64 s_k_total = k_tensor.shape[1] if cum_seqlen_k is not None else s_k64 - stride_b_qo = h_r64 * h_k64 * s_q64 * d64 if cum_seqlen_q is None else 0 - stride_b_kv = h_k64 * s_k64 * d64 if cum_seqlen_k is None else 0 b_lse = b64 if cum_seqlen_q is None else 1 stride_b_lse = h_r64 * h_k64 * s_lse64 if cum_seqlen_q is None else 0 # (s, d, ((h_r, h_k), b)) q_layout = cute.make_layout( (s_q_total, d, ((h_r, h_k), b)), - stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + stride=( + cute.assume(q_tensor.stride[1], divby=64), + q_tensor.stride[4], + ( + (q_tensor.stride[3], q_tensor.stride[2]), + 0 if cum_seqlen_q is not None else cute.assume(q_tensor.stride[0], divby=64), + ), + ), ) q = cute.make_tensor(q_tensor.iterator, q_layout) # (s, d, ((h_r, h_k), b)) do_layout = cute.make_layout( (s_q_total, d, ((h_r, h_k), b)), - stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + stride=( + cute.assume(do_tensor.stride[1], divby=64), + do_tensor.stride[4], + ( + (do_tensor.stride[3], do_tensor.stride[2]), + 0 if cum_seqlen_q is not None else cute.assume(do_tensor.stride[0], divby=64), + ), + ), ) do = cute.make_tensor(do_tensor.iterator, do_layout) # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast k_layout = cute.make_layout( (s_k_total, d, ((h_r, h_k), b)), - stride=(d64 * h_k64, 1, ((0, d64), stride_b_kv)), + stride=( + cute.assume(k_tensor.stride[1], divby=64), + k_tensor.stride[4], + ( + (0, k_tensor.stride[2]), + 0 if cum_seqlen_k is not None else cute.assume(k_tensor.stride[0], divby=64), + ), + ), ) k = cute.make_tensor(k_tensor.iterator, k_layout) # (d, s, ((h_r, h_k), b)), 0-stride for h_r to broadcast kt_layout = cute.make_layout( (d, s_k_total, ((h_r, h_k), b)), - stride=(1, d64 * h_k64, ((0, d64), stride_b_kv)), + stride=( + k_tensor.stride[4], + cute.assume(k_tensor.stride[1], divby=64), + ( + (0, k_tensor.stride[2]), + 0 if cum_seqlen_k is not None else cute.assume(k_tensor.stride[0], divby=64), + ), + ), ) kt = cute.make_tensor(k_tensor.iterator, kt_layout) # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast v_layout = cute.make_layout( (s_k_total, d, ((h_r, h_k), b)), - stride=(d64 * h_k64, 1, ((0, d64), stride_b_kv)), + stride=( + cute.assume(v_tensor.stride[1], divby=64), + v_tensor.stride[4], + ( + (0, v_tensor.stride[2]), + 0 if cum_seqlen_k is not None else cute.assume(v_tensor.stride[0], divby=64), + ), + ), ) v = cute.make_tensor(v_tensor.iterator, v_layout) # (s, ((h_r, h_k), b)) @@ -242,7 +274,14 @@ def __call__( # (s, d, ((h_r, h_k), b)) dq_layout = cute.make_layout( (s_q_total, d, ((h_r, h_k), b)), - stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)), + stride=( + cute.assume(dq_tensor.stride[1], divby=64), + dq_tensor.stride[4], + ( + (dq_tensor.stride[3], dq_tensor.stride[2]), + 0 if cum_seqlen_q is not None else cute.assume(dq_tensor.stride[0], divby=64), + ), + ), ) dq = cute.make_tensor(dq_tensor.iterator, dq_layout) From ab66326aaa4fe3529fbc00f3156f3a762dd3141b Mon Sep 17 00:00:00 2001 From: jayhshah Date: Fri, 8 May 2026 08:36:26 -0700 Subject: [PATCH 06/21] fix incorrect calculation of n_block global max for bwd deterministic (#2549) --- flash_attn/cute/block_info.py | 6 +++--- flash_attn/cute/flash_bwd_sm100.py | 7 +------ 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index 422da2b66a0..35bb4365ff6 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -143,8 +143,8 @@ def get_n_block_max_for_m_block( self, seqlen_info: SeqlenInfoQK, m_block: Int32, - n_block_global_max: Int32, ) -> Int32: + n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) if const_expr(self.is_causal or self.window_size_right is not None): m_idx_max = (m_block + 1) * self.tile_m if const_expr(self.qhead_per_kvhead_packgqa > 1): @@ -152,5 +152,5 @@ def get_n_block_max_for_m_block( n_idx_right = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q if const_expr(self.window_size_right is not None): n_idx_right += self.window_size_right - return min(n_block_global_max, cute.ceil_div(n_idx_right, self.tile_n)) - return n_block_global_max + n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.tile_n)) + return n_block_max diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 9184ddeb029..11db2dab563 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -3432,13 +3432,10 @@ def _dq_semaphore_lock_value( seqlen, m_block: Int32, n_block: Int32, - n_block_global_max: Int32, ) -> Int32: lock_value = n_block if const_expr(self.spt): - n_block_max_for_m_block = block_info.get_n_block_max_for_m_block( - seqlen, m_block, n_block_global_max - ) + n_block_max_for_m_block = block_info.get_n_block_max_for_m_block(seqlen, m_block) lock_value = n_block_max_for_m_block - 1 - n_block if const_expr(self.use_block_sparsity): assert blocksparse_tensors is not None @@ -3528,7 +3525,6 @@ def dQacc_reduce( mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] delay_semaphore_release = not self.tile_hdim == 192 and not self.use_block_sparsity - n_block_global_max = cute.ceil_div(seqlen.seqlen_k, self.tile_n) curr_q_cnt = Int32(0) curr_q_idx = None @@ -3627,7 +3623,6 @@ def dQacc_reduce( seqlen, m_block, n_block_cta_group, - n_block_global_max, ) barrier.wait_eq( mdQ_semaphore_cur[(m_block, None)].iterator, From 9bad4bec7326ad28edb5516b8878fd283f8991c0 Mon Sep 17 00:00:00 2001 From: liangel-02 Date: Tue, 12 May 2026 17:05:52 -0400 Subject: [PATCH 07/21] fix varlen w/ paging split kv bug (#2550) --- csrc/flash_attn/flash_api.cpp | 5 ++- tests/test_flash_attn.py | 58 +++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 70270f40fff..ca974949740 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -698,11 +698,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s params.page_block_size = page_block_size; // Keep references to these tensors to extend their lifetime at::Tensor softmax_lse_accum, out_accum; - if (paged_KV || seqlenq_ngroups_swapped) { + if (seqlenq_ngroups_swapped) { std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads, head_size, max_seqlen_k, max_seqlen_q, head_size_rounded, p_dropout, num_splits, get_num_sm(get_current_device()), opts); + } else if (paged_KV) { + TORCH_CHECK(num_splits <= 1, "num_splits > 1 is not supported for varlen paged KV"); + params.num_splits = num_splits; } if (leftpad_k_.has_value()) { diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index d5bb6ba8531..0589d1b2cd9 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -2523,3 +2523,61 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus assert torch.equal(dv, dv0) assert torch.equal(dk, dk0) assert torch.equal(dq, dq0) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_flash_attn_varlen_paged_kv_num_splits(dtype): + """Passing num_splits=0 explicitly should be bitwise identical to not passing it (default).""" + from flash_attn.flash_attn_interface import _flash_attn_varlen_forward + + device = "cuda" + num_heads, num_heads_k, head_dim = 4, 2, 64 + page_block_size = 256 + scale = head_dim ** -0.5 + + batch_size = 2 + kv_lens = [512, 1024] + max_seqlen_k = max(kv_lens) + + max_blocks_per_seq = max( + (s + page_block_size - 1) // page_block_size for s in kv_lens + ) + total_blocks = batch_size * max_blocks_per_seq + k_cache = torch.randn( + total_blocks, page_block_size, num_heads_k, head_dim, + device=device, dtype=dtype, + ) + v_cache = torch.randn( + total_blocks, page_block_size, num_heads_k, head_dim, + device=device, dtype=dtype, + ) + + block_table = rearrange( + torch.randperm(total_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + + q = torch.randn(batch_size, num_heads, head_dim, device=device, dtype=dtype) + cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=device) + seqused_k = torch.tensor(kv_lens, dtype=torch.int32, device=device) + cu_seqlens_k = torch.nn.functional.pad(seqused_k.cumsum(0), (1, 0)).to(torch.int32) + + fwd_kwargs = dict( + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=1, max_seqlen_k=max_seqlen_k, + dropout_p=0.0, softmax_scale=scale, + causal=False, window_size_left=-1, window_size_right=0, + block_table=block_table, seqused_k=seqused_k, + ) + + out_default = _flash_attn_varlen_forward(q, k_cache, v_cache, **fwd_kwargs)[0] + out_explicit = _flash_attn_varlen_forward(q, k_cache, v_cache, **fwd_kwargs, num_splits=0)[0] + + assert not out_default.isnan().any(), "default num_splits produced NaN" + assert torch.equal(out_default, out_explicit), ( + f"default vs num_splits=0 differ: max diff {(out_default - out_explicit).abs().max().item()}" + ) + + with pytest.raises(RuntimeError, match="num_splits > 1 is not supported"): + _flash_attn_varlen_forward(q, k_cache, v_cache, **fwd_kwargs, num_splits=2) From 484b9813f57814023387e91aab0364beaad096c8 Mon Sep 17 00:00:00 2001 From: Shivam Sharma <34232110+shivam2199@users.noreply.github.com> Date: Thu, 14 May 2026 02:41:46 +0530 Subject: [PATCH 08/21] Fix ZeroDivisionError in num_splits_heuristic for empty Q workloads (#2515) num_splits_heuristic divides num_SMs by total_mblocks, which collapses to 0 when seqlen_q == 0 or batch_size == 0 (e.g. CUDA graph padding or empty microbatches). The existing seqlen_k == 0 early-exit in _flash_attn_fwd does not cover these cases. - Extend the early-exit to also cover total_q == 0, using the same zero-output / -inf-LSE contract. total_q is batch_size * seqlen_q (dense) or q.shape[0] (varlen), so a single predicate handles both code paths. - Add a defensive total_mblocks == 0 guard inside num_splits_heuristic itself so the function is safe in isolation. - Add regression tests covering dense (batch=0, seqlen_q=0) and varlen (total_q=0) paths under both causal and non-causal masks. Fixes #2503. --- flash_attn/cute/interface.py | 7 ++- tests/cute/test_flash_attn.py | 100 ++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 0bd4651190c..45354e67559 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -256,6 +256,11 @@ def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits): # If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512. if num_n_blocks <= 4: return 1 + # Avoid ZeroDivisionError when batch_size or seqlen_q is 0. The empty-Q + # early-exit in _flash_attn_fwd handles correctness for those shapes; this + # guard just keeps the heuristic safe if called in other contexts. + if total_mblocks == 0: + return 1 # NOTE: We should revisit this heuristic after persistence is supported for split KV. # Sometimes, it's ideal to over-schedule splits for better efficiency. @@ -455,7 +460,7 @@ def _flash_attn_fwd( elif lse is not None: _validate_tensor(lse, "lse", lse_shape, torch.float32, device) - if seqlen_k == 0: + if seqlen_k == 0 or total_q == 0: out.zero_() if lse is not None: lse.fill_(float("-inf")) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 21ed3a48d57..2ebf338598c 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -2744,3 +2744,103 @@ def test_flash_attn_seqlen_k_zero(seqlen_q, d, causal): if lse is not None: assert torch.all(torch.isinf(lse) & (lse < 0)).item(), \ f"Expected all -inf LSE when seqlen_k=0, got: {lse}" + + +# --------------------------------------------------------------------------- +# Regression test (#2503): empty Q workload must not crash with ZeroDivisionError +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "shape", + [ + # (batch_size, seqlen_q) — the product drives num_splits_heuristic's divisor. + (2, 0), # seqlen_q == 0 (common in CUDA graph padding) + (0, 64), # batch_size == 0 (empty microbatch) + ], +) +@pytest.mark.parametrize("causal", [False, True]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_empty_q_dense(shape, causal): + """Dense (non-varlen) path must not raise ZeroDivisionError when Q is empty. + + num_splits_heuristic divides num_SMs by total_mblocks = batch_size * num_head_kv + * num_m_blocks. When seqlen_q == 0 or batch_size == 0, total_mblocks collapses + to 0 and the division crashes. The existing seqlen_k == 0 early-exit does not + cover this case. Regression for + https://github.com/Dao-AILab/flash-attention/issues/2503. + """ + batch_size, seqlen_q = shape + device = "cuda" + dtype = torch.bfloat16 + d = 128 + nheads = 16 + nheads_kv = 16 + # Pick seqlen_k large enough that num_n_blocks > 4 so the heuristic's own + # `num_n_blocks <= 4` early-return does not mask the bug. + seqlen_k = 4096 + + torch.manual_seed(0) + + q = torch.empty(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) + + out, lse = flash_attn_func(q, k, v, causal=causal) + + if is_fake_mode(): + return + + assert out.shape == (batch_size, seqlen_q, nheads, d), \ + f"Unexpected output shape: {out.shape}" + # With zero elements these are vacuously true, but we still want to assert + # the function returned cleanly rather than erroring out. + assert out.numel() == 0 + if lse is not None: + assert lse.numel() == 0 + + +@pytest.mark.parametrize("causal", [False, True]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_empty_q_varlen(causal): + """Varlen path must not raise ZeroDivisionError when total_q == 0. + + Parallels test_flash_attn_empty_q_dense for the cu_seqlens_q path, where + total_q = q.shape[0] feeds into total_mblocks. Regression for + https://github.com/Dao-AILab/flash-attention/issues/2503. + """ + device = "cuda" + dtype = torch.bfloat16 + d = 128 + nheads = 16 + nheads_kv = 16 + batch_size = 2 + seqlen_k_per_batch = 2048 + + torch.manual_seed(0) + + # All zero-length Q sequences — total_q == 0 while cu_seqlens_q is well-formed. + cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_k = torch.tensor( + [0, seqlen_k_per_batch, 2 * seqlen_k_per_batch], + dtype=torch.int32, device=device, + ) + total_k = int(cu_seqlens_k[-1].item()) + + q = torch.empty(0, nheads, d, device=device, dtype=dtype) + k = torch.randn(total_k, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(total_k, nheads_kv, d, device=device, dtype=dtype) + + out, lse = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=0, max_seqlen_k=seqlen_k_per_batch, + causal=causal, + ) + + if is_fake_mode(): + return + + assert out.shape == (0, nheads, d), f"Unexpected output shape: {out.shape}" + assert out.numel() == 0 + if lse is not None: + assert lse.numel() == 0 From 9cee95fd92ab49f585b57ad49bec5e4d67b5150c Mon Sep 17 00:00:00 2001 From: geruome <85235464+geruome@users.noreply.github.com> Date: Thu, 14 May 2026 07:03:45 +0800 Subject: [PATCH 09/21] [Cute, flex, sm90] fix sm90 flex (#2563) --- flash_attn/cute/flash_fwd_sm90.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py index bbbe7a063d9..23f92181166 100644 --- a/flash_attn/cute/flash_fwd_sm90.py +++ b/flash_attn/cute/flash_fwd_sm90.py @@ -888,6 +888,7 @@ def load( batch_idx, head_idx, m_block, + seqlen, kv_producer_state, tma_load_K_fn, tma_load_V_fn, From 0409f9adcbdebff6cc19eb95f370d40e896980bc Mon Sep 17 00:00:00 2001 From: Reuben Stern <107093092+reubenconducts@users.noreply.github.com> Date: Thu, 14 May 2026 13:34:24 -0400 Subject: [PATCH 10/21] split out varlen batch search into utils (#2556) * split out varlen batch search into utils * more descriptive name --- flash_attn/cute/compute_block_sparsity.py | 18 +++++++----------- flash_attn/cute/utils.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/flash_attn/cute/compute_block_sparsity.py b/flash_attn/cute/compute_block_sparsity.py index 69ef3c77619..777d3613eb1 100644 --- a/flash_attn/cute/compute_block_sparsity.py +++ b/flash_attn/cute/compute_block_sparsity.py @@ -18,7 +18,12 @@ get_aux_tensor_metadata, to_cute_aux_tensor, ) -from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar +from flash_attn.cute.utils import ( + hash_callable, + scalar_to_ssa, + ssa_to_scalar, + get_batch_from_cu_tensor, +) from flash_attn.cute.seqlen_info import SeqlenInfoQK @@ -161,16 +166,7 @@ class SharedStorage: m_block, head_idx, batch_idx = cute.arch.block_idx() else: global_m_block, head_idx, _ = cute.arch.block_idx() - # Binary search over cu_total_m_blocks to find batch_idx - lo = Int32(0) - hi = batch_size - while lo < hi: - mid = (lo + hi) // 2 - if mCuTotalMBlocks[mid + 1] <= global_m_block: - lo = mid + 1 - else: - hi = mid - batch_idx = lo + batch_idx = get_batch_from_cu_tensor(global_m_block, mCuTotalMBlocks) m_block = global_m_block - mCuTotalMBlocks[batch_idx] seqlen = SeqlenInfoCls(batch_idx) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index ffabd34e398..c8398c9a78d 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -949,3 +949,20 @@ def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: def ssa_to_scalar(val): """Could inline but nice for reflecting the above api""" return val[0] + + +@cute.jit +def get_batch_from_cu_tensor(idx: Int32, cu_tensor: cute.Tensor) -> Int32: + """Binary search to determine batch from packed index in a cumulative tensor""" + batch_size = cute.size(cu_tensor) - 1 + lo = Int32(0) + hi = batch_size + + while lo < hi: + mid = (lo + hi) // 2 + if cu_tensor[mid + 1] <= idx: + lo = mid + 1 + else: + hi = mid + + return lo From 8a8b2f10ddca88fd46db406c3e143e1ab0af977f Mon Sep 17 00:00:00 2001 From: jayhshah Date: Sat, 16 May 2026 15:03:48 -0700 Subject: [PATCH 11/21] allow for zero length sequences in hdim 256 sm100 kernels (#2568) --- .../cute/sm100_hd256_2cta_fmha_backward.py | 3 + ...100_hd256_2cta_fmha_backward_dkdvkernel.py | 48 ++-- ...sm100_hd256_2cta_fmha_backward_dqkernel.py | 251 ++++++++++-------- .../cute/sm100_hd256_2cta_fmha_forward.py | 23 +- tests/cute/test_flash_attn.py | 6 - 5 files changed, 198 insertions(+), 133 deletions(-) diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py index c07e3e94176..ecda0e273ad 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward.py @@ -21,6 +21,7 @@ from flash_attn.cute.sm100_hd256_2cta_fmha_backward_dkdvkernel import ( BlackwellFusedMultiHeadAttentionBackwardDKDVKernel, ) +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned def _as_bshkrd_tensor( @@ -251,6 +252,8 @@ def __call__( else: b = Q.shape[0] + Q, K, V, dQ, dK, dV, dO = [assume_tensor_aligned(t) for t in (Q, K, V, dQ, dK, dV, dO)] + Q = _as_bshkrd_tensor(Q, h_k, h_r, varlen) K = _as_bshkrd_tensor(K, h_k, 1, varlen) V = _as_bshkrd_tensor(V, h_k, 1, varlen) diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py index 6c9db87458f..885ae336f5f 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py @@ -32,6 +32,7 @@ Sm100FmhaStaticTileSchedulerParams as FmhaStaticTileSchedulerParams, ) +import flash_attn.cute.copy_utils as fa_copy_utils LAYOUT_RANK_CONSTANT = 3 @@ -2811,13 +2812,11 @@ def epilogue_clear( dK.iterator + mdK_offset, cute.make_layout((K, self.tile_shape_dQ_K, HB), stride=dK.stride), ) - gdK = cute.local_tile( - mdK, (self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1]), (None, None, None) - ) + gdK = cute.local_tile(mdK, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) gdK = gdK[None, None, blk_coord_k, 0, blk_coord_batch] cdK = cute.domain_offset( (blk_coord_k * self.tile_shape_K, 0), - cute.make_identity_tensor((self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1])), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), ) mdV_offset = cute.assume(blk_offset[1] * dV.stride[0], divby=64) @@ -2825,24 +2824,41 @@ def epilogue_clear( dV.iterator + mdV_offset, cute.make_layout((K, self.tile_shape_dV_dO, HB), stride=dV.stride), ) - gdV = cute.local_tile( - mdV, (self.PdO_mma_tiler[0], self.PdO_mma_tiler[1]), (None, None, None) - ) + gdV = cute.local_tile(mdV, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) gdV = gdV[None, None, blk_coord_k, 0, blk_coord_batch] cdV = cute.domain_offset( (blk_coord_k * self.tile_shape_K, 0), - cute.make_identity_tensor((self.PdO_mma_tiler[0], self.PdO_mma_tiler[1])), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), ) - for i in cutlass.range(tidx * 8, cute.size(gdK), block_dim_x * 8): - if cute.elem_less(cdK[i], cute.select(problem_shape, mode=[1, 2])): - gdK_i = cute.make_tensor(gdK.iterator + cute.assume(i, divby=8), (8)) - gdK_i.fill(0) + num_zero_epi_threads = 256 + + tiled_copy_r2g = fa_copy_utils.tiled_copy_2d( + dK.element_type, self.cta_tiler[2], num_zero_epi_threads + ) + + thr_copy_r2g = tiled_copy_r2g.get_slice(tidx) + + tRG_gdK = thr_copy_r2g.partition_D(gdK) + tRG_cdK = thr_copy_r2g.partition_D(cdK) + tRG_gdV = thr_copy_r2g.partition_D(gdV) + tRG_cdV = thr_copy_r2g.partition_D(cdV) + + zero_frg = cute.make_rmem_tensor_like(tRG_gdK[None, 0, None]) + zero_frg.fill(dK.element_type(0.0)) + + # check we don't need zero fragment duplication + V_frg_size = cute.size(tRG_gdV[None, 0, None]) + assert cute.size(zero_frg) == V_frg_size + + if tidx < num_zero_epi_threads: + for n in cutlass.range(cute.size(tRG_gdK.shape[1]), unroll_full=True): + if cute.elem_less(tRG_cdK[0, n, 0][0], problem_shape[1]): + cute.copy(tiled_copy_r2g, zero_frg, tRG_gdK[None, n, None]) - for i in cutlass.range(tidx * 8, cute.size(gdV), block_dim_x * 8): - if cute.elem_less(cdV[i], cute.select(problem_shape, mode=[1, 2])): - gdV_i = cute.make_tensor(gdV.iterator + cute.assume(i, divby=8), (8)) - gdV_i.fill(0) + for n in cutlass.range(cute.size(tRG_gdV.shape[1]), unroll_full=True): + if cute.elem_less(tRG_cdV[0, n, 0][0], problem_shape[1]): + cute.copy(tiled_copy_r2g, zero_frg, tRG_gdV[None, n, None]) @cute.jit def epilogue( diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py index a95c9677b65..25d6a91de70 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py @@ -29,6 +29,7 @@ Sm100FusedMask as FusedMask, ) from flash_attn.cute.tile_scheduler import SM100_TMEM_CAPACITY_COLUMNS +import flash_attn.cute.copy_utils as fa_copy_utils class BlackwellFusedMultiHeadAttentionBackwardDQKernel: @@ -924,36 +925,45 @@ def kernel( curr_block_coord[1], curr_block_coord[2], ) - continue_cond = False batch_coord = curr_block_coord[2][1] seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] cuseqlen_q = Int32(0) cuseqlen_k = Int32(0) - block_offset = ( - Int32(0), - Int32(0), - Int32(0), - ((Int32(0), Int32(0)), Int32(0)), - ) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k + + if has_work: block_offset = ( cuseqlen_q, cuseqlen_k, Int32(0), ((Int32(0), Int32(0)), Int32(0)), ) - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( - self.qk_mma_tiler[0], - mma_block_coord[0], - seqlen_q, - ) - if not continue_cond: mQ_qdl_ = cute.domain_offset(cute.select(block_offset, mode=[0, 2, 3]), mQ_qdl) mK_kdl_ = cute.domain_offset(cute.select(block_offset, mode=[1, 2, 3]), mK_kdl) mdO_qdl_ = cute.domain_offset( @@ -1057,18 +1067,6 @@ def kernel( # ((atom_v, rest_v), RestN, RestK) tKTgKT = tKgK_dkl[None, None, None, mma_block_coord[2]] - seqlen_kv_loop_start, seqlen_kv_loop_steps = ( - FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) - ) # LSE lse_handle = load_lse_producer.acquire_and_advance() # 32 threads loading 128 values of 32b each @@ -1197,6 +1195,9 @@ def kernel( if warp_idx == self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx mma_block_coord = ( @@ -1204,41 +1205,37 @@ def kernel( curr_block_coord[1], curr_block_coord[2], ) - continue_cond = False seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] batch_coord = curr_block_coord[2][1] + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) - - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k - - seqlen_kv_loop_start, seqlen_kv_loop_steps = ( - FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - is_leader_cta = cta_rank_in_cluster % 2 == 0 + if has_work: # dq_handle = mma_dq_producer.acquire_and_advance() load_q_releaser = load_q_consumer.clone() load_do_releaser = load_do_consumer.clone() @@ -1836,33 +1833,35 @@ def kernel( curr_block_coord[2], ) batch_coord = curr_block_coord[2][1] - continue_cond = False seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] cuseqlen_q = Int32(0) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + is_valid_k = trip_count > 0 + has_work = is_valid_q and is_valid_k - start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) + if has_work: end_count = start_count + trip_count if cutlass.const_expr(self.use_semantic_trip_range): n_block_min_causal_local_mask, n_block_min_before_local_mask = ( @@ -1932,6 +1931,7 @@ def kernel( ) lse_handle.release() sum_odo_handle.release() + work_tile = tile_sched.advance_to_next_work() ds_mma_producer.tail() @@ -1952,61 +1952,75 @@ def kernel( # cute.printf("batch_coord={}", batch_coord) seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] - continue_cond = False cuseqlen_q = Int32(0) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + mdQ_qdl_eff = mdQ_qdl + if cutlass.const_expr(cum_seqlen_q is not None): + block_offset_dQ = (cuseqlen_q,) + (None,) * 2 + mdQ_qdl_eff = cute.domain_offset(block_offset_dQ, mdQ_qdl) - mdQ_qdl_eff = mdQ_qdl - if cutlass.const_expr(cum_seqlen_q is not None): - block_offset_dQ = ( - cuseqlen_q, - Int32(0), - Int32(0), - ((Int32(0), Int32(0)), Int32(0)), - ) - mdQ_qdl_eff = cute.domain_offset( - cute.select(block_offset_dQ, mode=[0, 2, 3]), mdQ_qdl - ) + # (bM, bN, loopM, loopN, loopL) + gdQ_qdl = cute.flat_divide( + mdQ_qdl_eff, cute.select(self.dsk_block_tiler, mode=[0, 1]) + ) + cdQ_qdl = cute.flat_divide( + cute.make_identity_tensor(mdQ_qdl_eff.shape), + cute.select(self.dsk_block_tiler, mode=[0, 1]), + ) - # (bM, bN, loopM, loopN, loopL) - gdQ_qdl = cute.flat_divide( - mdQ_qdl_eff, cute.select(self.dsk_block_tiler, mode=[0, 1]) - ) - cdQ_qdl = cute.flat_divide( - cute.make_identity_tensor(mdQ_qdl_eff.shape), - cute.select(self.dsk_block_tiler, mode=[0, 1]), - ) + gdQ_staged = gdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + cdQ_staged = cdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + gdQ_tma_staged = gdQ_staged - gdQ_staged = gdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] - cdQ_staged = cdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] - gdQ_tma_staged = gdQ_staged - if cutlass.const_expr(not varlen): - gdQ_tma_qdl = cute.flat_divide( - mdQ_tma, cute.select(self.dsk_block_tiler, mode=[0, 1]) - ) - gdQ_tma_staged = gdQ_tma_qdl[ - None, None, curr_block_coord[0], None, curr_block_coord[2] - ] + if cutlass.const_expr(not varlen): + gdQ_tma_qdl = cute.flat_divide( + mdQ_tma, cute.select(self.dsk_block_tiler, mode=[0, 1]) + ) + gdQ_tma_staged = gdQ_tma_qdl[ + None, None, curr_block_coord[0], None, curr_block_coord[2] + ] + if has_work: # dQ TMEM to GMEM mma_dq_consumer = self.dQ_epilogue( - (seqlen_q, cuseqlen_q, mQ_qdl.shape[0], batch_coord), + seqlen_q, (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged), self.epi_tile, (tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen), ) + else: + self.dQ_epilogue_write_zero( + seqlen_q, + gdQ_staged, + cdQ_staged, + ) + work_tile = tile_sched.advance_to_next_work() # NOTE: tmem.free() moved to kernel end to enable cluster-wide sync @@ -2181,12 +2195,11 @@ def compute_step( @cute.jit def dQ_epilogue( self, - value_args: Tuple, + seqlen_q: int, dq_args: Tuple, epi_tile: cute.Tile, tma_args: Tuple, ) -> Tuple[pipeline.PipelineConsumer, pipeline.PipelineProducer]: - seqlen_q, cuseqlen_q, total_q, batch_coord = value_args (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged) = dq_args tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen = tma_args dq_handle = mma_dq_consumer.wait_and_advance() @@ -2274,3 +2287,31 @@ def dQ_epilogue( cute.autovec_copy(tSMrdQ, tTMEM_LOADgdQ_i) dq_handle.release() return mma_dq_consumer + + @cute.jit + def dQ_epilogue_write_zero( + self, + seqlen_q, + gdQ_staged, + cdQ_staged, + ): + num_epi_threads = self.threads_per_warp * len(self.epilogue_warp_ids) + tidx = cute.arch.thread_idx()[0] % num_epi_threads + + tiled_copy_r2g = fa_copy_utils.tiled_copy_2d( + self.dq_dtype, cute.size(gdQ_staged.shape[1]), num_epi_threads + ) + + thr_copy_r2g = tiled_copy_r2g.get_slice(tidx) + tdQgdQ_staged = thr_copy_r2g.partition_D(gdQ_staged) + tdQcdQ_staged = thr_copy_r2g.partition_D(cdQ_staged) + + tdQrdQ = cute.make_rmem_tensor_like(tdQgdQ_staged[None, 0, None, 0]) + tdQrdQ.fill(self.dq_dtype(0.0)) + + for iter in cutlass.range(self.iterations_dsk, unroll_full=True): + tdQgdQ = tdQgdQ_staged[None, None, None, iter] + tdQcdQ = tdQcdQ_staged[None, None, None, iter] + for m in cutlass.range(cute.size(tdQgdQ.shape[1]), unroll_full=True): + if cute.elem_less(tdQcdQ[0, m, 0][0], seqlen_q): + cute.copy(tiled_copy_r2g, tdQrdQ, tdQgdQ[None, m, None]) diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py index 28087125f47..379cebc1905 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py @@ -1030,6 +1030,9 @@ def kernel( if warp_idx == self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx mma_block_coord = ( @@ -1071,10 +1074,6 @@ def kernel( ) seqlen_kv_loop_end = seqlen_kv_loop_start + seqlen_kv_loop_steps - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - is_leader_cta = cta_rank_in_cluster % 2 == 0 load_q_releaser = load_q_consumer.clone() pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) if seqlen_kv_loop_steps > 1: @@ -1323,6 +1322,11 @@ def kernel( window_size_right, ) end_count = start_count + trip_count + # require at least one softmax iteration for zero trip_count case; + # rely on masking this iteration for correctness + if end_count <= start_count: + start_count = 0 + end_count = 1 if cutlass.const_expr(self.use_semantic_trip_range): n_block_min_causal_local_mask, n_block_min_before_local_mask = ( FusedMask.get_trip_mask_bounds_via_block_info( @@ -1349,6 +1353,7 @@ def kernel( need_apply_mask = ( step >= n_block_min_causal_local_mask or step < n_block_min_before_local_mask + or step == end_count - 1 ) else: # Residual path only needs seqlen masking on the last K tile. @@ -1797,7 +1802,8 @@ def correction_epilog( row_sum = sSum[thread_idx] cute.arch.fence_view_async_shared() sum_handle.release() - scale = scale_output / row_sum + row_sum_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum + scale = scale_output / row_sum if not row_sum_is_zero_or_nan else 0.0 o_handle = mma_o_consumer.wait_and_advance() for iter in cutlass.range(self.iterations_pv): gO = gO_staged[None, None, iter] @@ -1855,6 +1861,7 @@ def store_sum_max( sSum[thread_idx] = row_sum cute.arch.fence_view_async_shared() sum_handle.commit() + row_sum_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum if cutlass.const_expr(mLSE is not None): q_idx = current_block_coord[0] * self.cta_tiler[0] + tidx @@ -1863,7 +1870,11 @@ def store_sum_max( if cutlass.const_expr(cum_seqlen_q is not None) else current_block_coord[2] ) - lse_value = scale_softmax * row_max + cute.math.log(row_sum, fastmath=True) + lse_value = ( + scale_softmax * row_max + cute.math.log(row_sum, fastmath=True) + if not row_sum_is_zero_or_nan + else -Float32.inf + ) if cute.elem_less(q_idx, seqlen_q): global_q_idx = ( q_idx + cuseqlen_q if cutlass.const_expr(cum_seqlen_q is not None) else q_idx diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 2ebf338598c..764d7123681 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -155,8 +155,6 @@ def test_flash_attn_output( pytest.skip("SM100 head_dim=256 2CTA kernel does not support softcap yet") if deterministic: pytest.skip("SM100 head_dim=256 2CTA kernel does not support deterministic mode yet") - if causal and seqlen_q > seqlen_k: - pytest.skip("SM100 head_dim=256 2CTA kernel does not support causal attention with seqlen_q > seqlen_k yet") device = "cuda" # set seed seed = 0 @@ -551,10 +549,6 @@ def test_flash_attn_varlen_output( pytest.skip("SM100 head_dim=256 2CTA kernel does not support softcap yet") if deterministic: pytest.skip("SM100 head_dim=256 2CTA kernel does not support deterministic mode yet") - if causal and seqlen_q > seqlen_k: - pytest.skip("SM100 head_dim=256 2CTA kernel does not support causal attention with seqlen_q > seqlen_k yet") - if zero_lengths_q or zero_lengths_k: - pytest.skip("SM100 head_dim=256 2CTA kernel does not support zero-length sequences yet") if not unpad_q or not unpad_kv: pytest.skip("SM100 head_dim=256 2CTA kernel does not support seqused_q/seqused_k mode yet (requires unpad_q=True and unpad_kv=True)") if ( From 4178915405522cb161ece80005e882c7465693b7 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 19 May 2026 16:41:43 -0700 Subject: [PATCH 12/21] Enable split-kv for blocksparse tensors (#2536) stack-info: PR: https://github.com/Dao-AILab/flash-attention/pull/2536, branch: drisspg/stack/38 --- flash_attn/cute/block_sparse_utils.py | 95 ++++++++++------- flash_attn/cute/flash_fwd_sm100.py | 15 ++- flash_attn/cute/interface.py | 4 - tests/cute/test_mask_mod.py | 140 ++++++++++++++++++++++++++ tests/cute/test_mask_mod_varlen.py | 139 +++++++++++++++++++++++++ 5 files changed, 354 insertions(+), 39 deletions(-) diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index def4f088d92..fb131745b3b 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -550,10 +550,20 @@ def consume_block_sparse_loads( return kv_consumer_state, O_should_accumulate, processed_any +@cute.jit +def split_block_range(block_count, split_idx: Int32, num_splits: Int32): + """Return the half-open block-list range assigned to one SplitKV partition.""" + blocks_per_split = cute.ceil_div(block_count, num_splits) + block_begin = cutlass.min(split_idx * blocks_per_split, block_count) + block_end = cutlass.min(block_begin + blocks_per_split, block_count) + return block_begin, block_end + + @cute.jit def load_block_list_sm100( block_indices: cute.Tensor, - block_count, + block_begin, + block_end, load_q_with_first: cutlass.Constexpr, q_stage: cutlass.Constexpr, kv_producer_state, @@ -563,9 +573,10 @@ def load_block_list_sm100( pipeline_kv, ): """SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count).""" + block_count = block_end - block_begin if block_count > 0: # First iteration: load Q alongside K if requested - n_block_first = block_indices[block_count - 1] + n_block_first = block_indices[block_end - 1] if const_expr(load_q_with_first): # SM100 loads Q0 and optionally Q1 @@ -582,7 +593,7 @@ def load_block_list_sm100( # Remaining blocks for offset in cutlass.range(1, block_count): - n_block = block_indices[block_count - 1 - offset] + n_block = block_indices[block_end - 1 - offset] load_K(block=n_block, producer_state=kv_producer_state, page_idx=None) kv_producer_state.advance() load_V(block=n_block, producer_state=kv_producer_state, page_idx=None) @@ -598,7 +609,9 @@ def produce_block_sparse_loads_sm100( batch_idx, head_idx, m_block, - seqlen_info, + seqlen_info: SeqlenInfoQK, + split_idx: Int32, + num_splits: Int32, kv_producer_state, load_Q, load_K, @@ -633,8 +646,10 @@ def produce_block_sparse_loads_sm100( seqlen_info, ) - mask_empty = curr_mask_block_cnt == 0 - full_empty = curr_full_block_cnt == 0 + mask_begin, mask_end = split_block_range(curr_mask_block_cnt, split_idx, num_splits) + full_begin, full_end = split_block_range(curr_full_block_cnt, split_idx, num_splits) + mask_empty = mask_begin == mask_end + full_empty = full_begin == full_end q_phase_flipped = False @@ -642,7 +657,8 @@ def produce_block_sparse_loads_sm100( # No masked blocks: process full list with Q loading kv_producer_state = load_block_list_sm100( curr_full_block_idx, - curr_full_block_cnt, + full_begin, + full_end, load_q_with_first=True, q_stage=q_stage, kv_producer_state=kv_producer_state, @@ -656,7 +672,8 @@ def produce_block_sparse_loads_sm100( # Process masked blocks with Q loading kv_producer_state = load_block_list_sm100( curr_mask_block_idx, - curr_mask_block_cnt, + mask_begin, + mask_end, load_q_with_first=True, q_stage=q_stage, kv_producer_state=kv_producer_state, @@ -671,7 +688,8 @@ def produce_block_sparse_loads_sm100( # Process full blocks without Q loading kv_producer_state = load_block_list_sm100( curr_full_block_idx, - curr_full_block_cnt, + full_begin, + full_end, load_q_with_first=False, q_stage=q_stage, kv_producer_state=kv_producer_state, @@ -693,26 +711,29 @@ def get_total_block_count( batch_idx, head_idx, m_block, + split_idx: Int32, + num_splits: Int32, qhead_per_kvhead: cutlass.Constexpr, q_subtile_factor: cutlass.Constexpr, seqlen_info: SeqlenInfoQK, ): m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - mask_block_cnt, _, full_block_cnt, *_ = blocksparse_tensors - - if const_expr(len(mask_block_cnt.shape) == 2): - # varlen path: tensors are [num_heads, total_m_block] - curr_m = seqlen_info.m_block_offset + m_block_sparse - total = mask_block_cnt[head_idx, curr_m] - if const_expr(full_block_cnt is not None): - total += full_block_cnt[head_idx, curr_m] - else: - # non-varlen: tensors are [batch, num_heads, m_block] - total = mask_block_cnt[batch_idx, head_idx, m_block_sparse] - if const_expr(full_block_cnt is not None): - total += full_block_cnt[batch_idx, head_idx, m_block_sparse] + ( + curr_mask_block_cnt, + _, + curr_full_block_cnt, + _, + ) = get_curr_blocksparse_tensors( + batch_idx, + head_idx, + m_block_sparse, + blocksparse_tensors, + seqlen_info, + ) - return total + mask_begin, mask_end = split_block_range(curr_mask_block_cnt, split_idx, num_splits) + full_begin, full_end = split_block_range(curr_full_block_cnt, split_idx, num_splits) + return mask_end - mask_begin + full_end - full_begin @cute.jit @@ -839,7 +860,9 @@ def softmax_block_sparse_sm100( batch_idx, head_idx, m_block, - seqlen_info, + seqlen_info: SeqlenInfoQK, + split_idx: Int32, + num_splits: Int32, softmax_step: Callable, mask_fn: Callable, mask_fn_none: Callable, @@ -870,13 +893,17 @@ def softmax_block_sparse_sm100( seqlen_info, ) - total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt + mask_begin, mask_end = split_block_range(curr_mask_block_cnt, split_idx, num_splits) + full_begin, full_end = split_block_range(curr_full_block_cnt, split_idx, num_splits) + split_mask_block_cnt = mask_end - mask_begin + split_full_block_cnt = full_end - full_begin + total_block_cnt = split_mask_block_cnt + split_full_block_cnt if total_block_cnt == 0: sm_stats_barrier.arrive_w_index(index=stage_idx * 4 + warp_idx) else: - if curr_mask_block_cnt > 0: - mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] + if split_mask_block_cnt > 0: + mask_n_block = curr_mask_block_idx[mask_end - 1] ( mma_si_consumer_phase, si_corr_producer_phase, @@ -889,8 +916,8 @@ def softmax_block_sparse_sm100( is_first=True, mask_fn=partial(mask_fn, mask_seqlen=True, check_q_boundary=check_m_boundary), ) - for i in cutlass.range(1, curr_mask_block_cnt): - mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + for i in cutlass.range(1, split_mask_block_cnt): + mask_n_block = curr_mask_block_idx[mask_end - 1 - i] ( mma_si_consumer_phase, si_corr_producer_phase, @@ -903,9 +930,9 @@ def softmax_block_sparse_sm100( mask_fn=partial(mask_fn, mask_seqlen=False, check_q_boundary=check_m_boundary), ) - if curr_full_block_cnt > 0: - full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] - if curr_mask_block_cnt == 0: + if split_full_block_cnt > 0: + full_n_block = curr_full_block_idx[full_end - 1] + if split_mask_block_cnt == 0: ( mma_si_consumer_phase, si_corr_producer_phase, @@ -935,8 +962,8 @@ def softmax_block_sparse_sm100( mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary ), ) - for i in cutlass.range(1, curr_full_block_cnt): - full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + for i in cutlass.range(1, split_full_block_cnt): + full_n_block = curr_full_block_idx[full_end - 1 - i] ( mma_si_consumer_phase, si_corr_producer_phase, diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 4d38174c2c8..55a92f690bd 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1211,6 +1211,7 @@ def kernel( num_splits, SeqlenInfoCls, mma_tile_coord_v, + blocksparse_tensors=blocksparse_tensors, tile_scheduler=tile_scheduler, ) @@ -1499,6 +1500,8 @@ def load( head_idx, m_block, seqlen, + split_idx, + num_splits, kv_producer_state, load_Q, load_K, @@ -1646,6 +1649,8 @@ def mma( batch_idx, head_idx, m_block, + split_idx, + num_splits, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, self.q_subtile_factor if self.q_subtile_factor is not None else 1, seqlen_info=seqlen, @@ -2013,6 +2018,8 @@ def softmax_loop( batch_idx, head_idx, m_block, + split_idx, + num_splits, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, self.q_subtile_factor if self.q_subtile_factor is not None else 1, seqlen_info=seqlen, @@ -2072,6 +2079,8 @@ def softmax_loop( head_idx, m_block, seqlen, + split_idx, + num_splits, softmax_step, mask_fn, mask_fn_none, @@ -2427,6 +2436,8 @@ def correction_loop( batch_idx, head_idx, m_block, + split_idx, + num_splits, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, self.q_subtile_factor if self.q_subtile_factor is not None else 1, seqlen_info=seqlen, @@ -2841,6 +2852,7 @@ def epilogue_s2g( num_splits: int, SeqlenInfoCls: Callable, mma_tile_coord_v: Int32 = 0, + blocksparse_tensors: Optional[BlockSparseTensors] = None, tile_scheduler=None, ): epi_consumer_phase = Int32(0) @@ -2849,8 +2861,9 @@ def epilogue_s2g( m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + has_work = const_expr(self.use_block_sparsity or not self.is_split_kv) or n_block_min < n_block_max - if const_expr(not self.is_split_kv) or n_block_min < n_block_max: + if has_work: if const_expr(self.is_split_kv): mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] else: diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 45354e67559..189ae1faca7 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -609,10 +609,6 @@ def _flash_attn_fwd( head_dim_idx = 0 if block_sparse_tensors.mask_block_cnt.ndim == 2 else 1 if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[head_dim_idx] != 1: pack_gqa = False - if is_split_kv: - raise NotImplementedError( - "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split." - ) if cu_seqlens_q is not None: assert block_sparse_tensors.cu_total_m_blocks is not None, ( "Varlen block sparsity requires block_sparse_tensors.cu_total_m_blocks." diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 9e4c440fe0c..52f78e8d26d 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -2258,6 +2258,146 @@ def doc_mask(b, h, q_idx, kv_idx): _run_write_order_test(doc_mask, seqlen_q, seqlen_k, block_size=128, B=B, H=H, spt=spt) +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="SM100/SM110 SplitKV block sparse forward only") +def test_block_sparse_splitkv_matches_unsplit(): + torch.manual_seed(123) + batch_size = 1 + nheads = 4 + seqlen = 2048 + headdim = 64 + tile_m = 128 + tile_n = 128 + dtype = torch.bfloat16 + sparse_tile_m = 2 * tile_m + + mask_mod_cute, mask_mod_flex = get_mask_pair("causal", seqlen_q=seqlen, seqlen_k=seqlen) + tensors = create_tensors( + batch_size, seqlen, seqlen, nheads, nheads, headdim, headdim, dtype + ) + + bm = create_block_mask( + mask_mod_flex, + batch_size, + nheads, + seqlen, + seqlen, + device="cuda", + BLOCK_SIZE=(sparse_tile_m, tile_n), + ) + (_, _, kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, *_) = bm.as_tuple() + block_sparse_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + block_size=(sparse_tile_m, tile_n), + ) + + out_unsplit, lse_unsplit = _flash_attn_fwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=tensors["out"].clone(), + lse=tensors["lse"].clone(), + softmax_scale=1.0 / math.sqrt(headdim), + causal=False, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_fwd, + num_splits=1, + return_lse=True, + ) + out_split, lse_split = _flash_attn_fwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=tensors["out"].clone(), + lse=tensors["lse"].clone(), + softmax_scale=1.0 / math.sqrt(headdim), + causal=False, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_fwd, + num_splits=3, + return_lse=True, + ) + + out_ref = compute_reference_flex_attn(tensors, mask_mod_flex, block_size=(sparse_tile_m, tile_n)) + out_ref_fp32 = compute_reference_flex_attn( + {name: tensor.float() for name, tensor in tensors.items()}, + mask_mod_flex, + block_size=(sparse_tile_m, tile_n), + ) + + assert_fwd_matches_reference(out_split, out_ref_fp32, out_ref) + assert torch.allclose(lse_split, lse_unsplit, atol=2e-3, rtol=2e-3) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="SM100/SM110 SplitKV block sparse forward only") +def test_block_sparse_splitkv_oversplit_sparse_blocks(): + torch.manual_seed(321) + batch_size = 1 + nheads = 4 + seqlen_q = 513 + seqlen_k = 1024 + headdim = 64 + tile_m = 128 + tile_n = 128 + dtype = torch.bfloat16 + sparse_tile_m = 2 * tile_m + + mask_mod_cute, mask_mod_flex = get_mask_pair("block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k) + tensors = create_tensors( + batch_size, seqlen_q, seqlen_k, nheads, nheads, headdim, headdim, dtype + ) + bm = create_block_mask( + mask_mod_flex, + batch_size, + nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(sparse_tile_m, tile_n), + ) + (_, _, kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, *_) = bm.as_tuple() + block_sparse_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + block_size=(sparse_tile_m, tile_n), + ) + + out_unsplit, _ = _flash_attn_fwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=tensors["out"].clone(), + lse=tensors["lse"].clone(), + softmax_scale=1.0 / math.sqrt(headdim), + causal=False, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_fwd, + num_splits=1, + return_lse=True, + ) + out_split, _ = _flash_attn_fwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=tensors["out"].clone(), + lse=tensors["lse"].clone(), + softmax_scale=1.0 / math.sqrt(headdim), + causal=False, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_fwd, + num_splits=8, + return_lse=True, + ) + + assert not torch.isnan(out_split).any() + assert torch.isfinite(out_split).all() + assert torch.allclose(out_split, out_unsplit, atol=4e-3, rtol=4e-3) + + def test_compact_block_sparse_indices(): """Test that compact block sparse index tensors (idx.shape[3] < n_blocks) work correctly. diff --git a/tests/cute/test_mask_mod_varlen.py b/tests/cute/test_mask_mod_varlen.py index e935d4d4430..70d08763cb3 100644 --- a/tests/cute/test_mask_mod_varlen.py +++ b/tests/cute/test_mask_mod_varlen.py @@ -618,6 +618,18 @@ def test_varlen_global_masks(seqlens_q, seqlens_k, mask_name): # ============================================================================= +@cute.jit +def cute_all_true_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + return m_idx >= utils.scalar_to_ssa(0, cutlass.Int32) + + def _make_block_sparse_tensors( mask_mod, seqlens_q, @@ -691,6 +703,7 @@ def _run_fwd( seqused_k=None, block_sparse_tensors=None, aux_tensors=None, + num_splits=1, ): out = torch.empty_like(q) return _flash_attn_fwd( @@ -718,6 +731,7 @@ def _run_fwd( block_sparse_tensors=block_sparse_tensors, return_lse=False, aux_tensors=aux_tensors, + num_splits=num_splits, )[0] @@ -891,5 +905,130 @@ def make_cu_seqlens(seqlens): ) +VARLEN_BLOCK_SPARSE_SPLITKV_SEQLENS = [ + ([128], [2048]), + ([96], [1536]), + ([128, 64], [2048, 1024]), + ([1, 128], [256, 2048]), +] + +VARLEN_BLOCK_SPARSE_SPLITKV_MASKS = [ + "all_true", + "causal", + "sliding_window", + "prefix_lm", + "block_diagonal", +] + + +def _get_splitkv_varlen_mask(mask_name, max_seqlen_q, max_seqlen_k): + match mask_name: + case "all_true": + return cute_all_true_mask + case "causal": + return get_mask_pair("causal", seqlen_q=max_seqlen_q, seqlen_k=max_seqlen_k)[0] + case "sliding_window": + return get_mask_pair( + "sliding_window", + seqlen_q=max_seqlen_q, + seqlen_k=max_seqlen_k, + window_size=512, + )[0] + case _: + return get_mask_pair(mask_name)[0] + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="SM100/SM110 SplitKV forward only") +@pytest.mark.parametrize("use_seqused_k", [False, True]) +@pytest.mark.parametrize("mask_name", VARLEN_BLOCK_SPARSE_SPLITKV_MASKS) +@pytest.mark.parametrize("seqlens_q,seqlens_k", VARLEN_BLOCK_SPARSE_SPLITKV_SEQLENS) +def test_varlen_block_sparse_splitkv_matches_unsplit(seqlens_q, seqlens_k, mask_name, use_seqused_k): + """Varlen block-sparse SplitKV should match the unsplit block-sparse path.""" + torch.manual_seed(123) + random.seed(123) + device = "cuda" + num_heads = 4 + head_dim = 128 + dtype = torch.bfloat16 + sparse_tile_m = 256 if sum(seqlens_q) > 128 else 128 + tile_n = 128 + max_seqlen_q = max(seqlens_q) + max_seqlen_k = max(seqlens_k) + + q = torch.randn(sum(seqlens_q), num_heads, head_dim, device=device, dtype=dtype) + cu_seqlens_q = torch.tensor( + [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()), + device=device, + dtype=torch.int32, + ) + mask_mod = _get_splitkv_varlen_mask(mask_name, max_seqlen_q, max_seqlen_k) + + if use_seqused_k: + k = torch.randn( + len(seqlens_k), max_seqlen_k, num_heads, head_dim, device=device, dtype=dtype + ) + v = torch.randn_like(k) + cu_seqlens_k = None + seqused_k = torch.tensor(seqlens_k, dtype=torch.int32, device=device) + else: + k = torch.randn(sum(seqlens_k), num_heads, head_dim, device=device, dtype=dtype) + v = torch.randn_like(k) + cu_seqlens_k = torch.tensor( + [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()), + device=device, + dtype=torch.int32, + ) + seqused_k = None + + block_sparse_tensors = _make_block_sparse_tensors( + mask_mod=mask_mod, + seqlens_q=seqlens_q, + seqlens_k=seqlens_k, + num_heads=1, + tile_m=sparse_tile_m, + tile_n=tile_n, + device=device, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + ) + + out_unsplit = _run_fwd( + q, + k, + v, + mask_mod, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + block_sparse_tensors=block_sparse_tensors, + ) + out_split = _run_fwd( + q, + k, + v, + mask_mod, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + block_sparse_tensors=block_sparse_tensors, + num_splits=3, + ) + out_no_block_sparsity = _run_fwd( + q, + k, + v, + mask_mod, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + ) + + assert not torch.isnan(out_split).any(), "NaN in SplitKV block-sparse output" + assert torch.isfinite(out_split).all(), "Inf in SplitKV block-sparse output" + assert (out_unsplit - out_no_block_sparsity).abs().max().item() <= 0.01 + assert (out_split - out_no_block_sparsity).abs().max().item() <= 0.01 + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) From 0cb66b415844d85699b5e32d787a70782f5fe206 Mon Sep 17 00:00:00 2001 From: sryap <17482891+sryap@users.noreply.github.com> Date: Fri, 22 May 2026 15:49:08 -0700 Subject: [PATCH 13/21] Wrap mask contruction in a function for mask subclassing (#2584) Summary: Extract the inline `AttentionMask` construction in `FlashAttentionForwardSm100` and `FlashAttentionBackwardSm100` into an overridable `_generate_attention_mask_cls` method. This allows subclasses to inject a custom `AttentionMask` without modifying the base kernel code. For example, a custom attention kernel can override the mask to add a `causal_q_divisor` field for scaling the `row_idx` value. ``` class CustomAttentionMask(AttentionMask): causal_q_divisor: cutlass.Constexpr[int] = 1 @cute.jit def apply_mask_sm100(self, acc_S, m_block, n_block, ...): # Custom causal logic using causal_q_divisor row_idx = (tScS_t2r[0][0] + m_block * self.tile_m) // self.causal_q_divisor ... class CustomFlashAttentionForwardSm100(FlashAttentionForwardSm100): def __init__(self, *args, causal_q_divisor=1, **kwargs): super().__init__(*args, **kwargs) self.causal_q_divisor = causal_q_divisor def _generate_attention_mask_cls(self, window_size_left, window_size_right): return partial( CustomAttentionMask, self.m_block_size, self.n_block_size, window_size_left=window_size_left, window_size_right=window_size_right, bottom_right=self.is_bottom_right, causal_q_divisor=self.causal_q_divisor, ) ``` Test Plan: ``` $ pytest tests/cute/test_flash_attn_fast.py -v ================ 240 passed, 4139 warnings in 984.24s (0:16:24) ================ ``` Reviewers: Subscribers: Tasks: Tags: --- flash_attn/cute/flash_bwd_sm100.py | 19 ++++++++++++------- flash_attn/cute/flash_fwd_sm100.py | 21 ++++++++++++++------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 11db2dab563..81462e50afd 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -1010,6 +1010,16 @@ class SharedStorage: min_blocks_per_mp=1, ) + def _generate_attention_mask_cls(self, window_size_left, window_size_right): + return partial( + AttentionMask, + self.tile_m, + self.tile_n * self.cta_group_size, + swap_AB=True, + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + @cute.kernel def kernel( self, @@ -1413,13 +1423,8 @@ def kernel( ) TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) - AttentionMaskCls = partial( - AttentionMask, - self.tile_m, - self.tile_n * self.cta_group_size, - swap_AB=True, - window_size_left=window_size_left, - window_size_right=window_size_right, + AttentionMaskCls = self._generate_attention_mask_cls( + window_size_left, window_size_right ) # EMPTY # (15) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 55a92f690bd..dc99022c2a5 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -772,6 +772,18 @@ class SharedStorage: min_blocks_per_mp=1, ) + def _generate_attention_mask_cls(self, window_size_left, window_size_right): + return partial( + AttentionMask, + self.m_block_size, + self.n_block_size, + window_size_left=window_size_left, + window_size_right=window_size_right, + qhead_per_kvhead_packgqa=( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ), + ) + # GPU device kernel @cute.kernel def kernel( @@ -1060,13 +1072,8 @@ def kernel( blocksparse_tensors.cu_block_idx_offsets if blocksparse_tensors is not None else None ), ) - AttentionMaskCls = partial( - AttentionMask, - self.m_block_size, - self.n_block_size, - window_size_left=window_size_left, - window_size_right=window_size_right, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + AttentionMaskCls = self._generate_attention_mask_cls( + window_size_left, window_size_right ) # Cluster wait before tensor memory alloc pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) From 3da76cdb8aedd842c46511e5194f5f20cdd4cf6f Mon Sep 17 00:00:00 2001 From: Shivang <40532945+aw920h@users.noreply.github.com> Date: Sat, 23 May 2026 04:30:00 +0530 Subject: [PATCH 14/21] Build Fix: Update abi3 tag to cp310 and minimum python version to 3.10 (#2532) * Fix: Remove misleading py_limited_api=cp39 wheel tag for PyTorch extension * Implement dynamic ABI tagging for PyTorch versions Add dynamic ABI tag based on PyTorch version for correct and improved naming of the wheel. * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * Update Python version requirements based on torch metadata * Refactor setup.py for dynamic ABI and CUDA settings Refactor dynamic ABI tag and Python version requirements based on installed PyTorch version and streamline CUDA extension arguments. * Update CUDAExtension compile arguments Restored some accidentally removed content * Update setup.py * Updated setup.py minor fix: cleaned up the comments * Brought back Py_LIMITED_API flag to CUDA extension compilation * Minor fix * Update setup.py for Python version requirements Updated the wheel tag to cp310 and python_requires=">=3.10". --------- Co-authored-by: aw920h Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- hopper/setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hopper/setup.py b/hopper/setup.py index 87f6f45af97..e13d9460785 100755 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -732,12 +732,12 @@ def run(self): else { "bdist_wheel": CachedWheelsCommand, }, - python_requires=">=3.8", + python_requires=">=3.10", install_requires=[ "torch", "einops", "packaging", "ninja", ], - options={"bdist_wheel": {"py_limited_api": "cp39"}}, + options={"bdist_wheel": {"py_limited_api": "cp310"}}, ) From fe5fb1b86c0e6711f2332832760f6c3e6c34be9e Mon Sep 17 00:00:00 2001 From: Reuben Stern <107093092+reubenconducts@users.noreply.github.com> Date: Sat, 23 May 2026 21:26:36 -0400 Subject: [PATCH 15/21] [Cute,Flex,Sm100] vectorized mask_mod (#2261) * vectorized mask mod application for existing mask mod signatures * add vectorized mask mod examples, get vectorized evaluation and application working * guard sm80/90/120 against mask_vec_size > 2 * thread mask_vec_size thru sm80/90/120 kernel * Small tweaks coverign sm90 * Small tweaks coverign sm90 --------- Co-authored-by: drisspg --- flash_attn/cute/flash_fwd.py | 14 +- flash_attn/cute/flash_fwd_sm100.py | 6 +- flash_attn/cute/flash_fwd_sm90.py | 6 +- flash_attn/cute/mask.py | 287 ++++++++++++++++++++++++----- flash_attn/cute/utils.py | 2 +- tests/cute/mask_mod_definitions.py | 256 +++++++++++++++++++++++++ tests/cute/test_mask_mod.py | 125 ++++++++++++- tests/cute/test_mask_mod_varlen.py | 139 ++++++++++++++ tests/cute/test_utils.py | 13 ++ 9 files changed, 793 insertions(+), 55 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index d1a43cfd247..7b74c2f7b0f 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -99,12 +99,18 @@ def __init__( self.score_mod = score_mod self.mask_mod = mask_mod self.qk_acc_dtype = Float32 - self.vec_size: cutlass.Constexpr = getattr( + self.score_vec_size: cutlass.Constexpr = getattr( score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 ) - if self.vec_size > 2: + if self.score_vec_size > 2: raise ValueError( - f"score_mod vec_size {self.vec_size} not supported on Sm80/90/120 " + f"score_mod vec_size {self.score_vec_size} not supported on Sm80/90/120 " + "due to accumulator thread ownership pattern." + ) + self.mask_vec_size: cutlass.Constexpr = getattr(mask_mod, "__vec_size__", 1) + if self.mask_vec_size > 1: + raise ValueError( + f"mask_mod vec_size {self.mask_vec_size} not supported on Sm80/90/120 " "due to accumulator thread ownership pattern." ) self.arch = BaseDSL._get_dsl().get_arch_enum() @@ -1211,7 +1217,7 @@ def apply_score_mod( batch_idx, head_idx, softmax_scale, - self.vec_size, + self.score_vec_size, self.qk_acc_dtype, aux_tensors, fastdiv_mods, diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index dc99022c2a5..cd5dd1eb7c9 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -185,9 +185,10 @@ def __init__( ) self.score_mod = score_mod self.mask_mod = mask_mod - self.vec_size: cutlass.Constexpr = getattr( + self.score_vec_size: cutlass.Constexpr = getattr( score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 ) + self.mask_vec_size: cutlass.Constexpr = getattr(mask_mod, "__vec_size__", 1) # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f @@ -1955,6 +1956,7 @@ def softmax_loop( batch_idx=batch_idx, head_idx=head_idx, aux_tensors=aux_tensors, + vec_size=self.mask_vec_size, ) # Recompute fastdiv_mods if necessary @@ -3127,7 +3129,7 @@ def apply_score_mod( batch_idx, head_idx, softmax.softmax_scale, - self.vec_size, + self.score_vec_size, self.qk_acc_dtype, aux_tensors, fastdiv_mods, diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py index 23f92181166..3d57d6718fc 100644 --- a/flash_attn/cute/flash_fwd_sm90.py +++ b/flash_attn/cute/flash_fwd_sm90.py @@ -557,7 +557,9 @@ def kernel( blocksparse_tensors.cu_total_m_blocks if blocksparse_tensors is not None else None ), mCuBlockIdxOffsets=( - blocksparse_tensors.cu_block_idx_offsets if blocksparse_tensors is not None else None + blocksparse_tensors.cu_block_idx_offsets + if blocksparse_tensors is not None + else None ), # Don't need to pass in tile_mn because we won't access offset_padded ) @@ -1508,7 +1510,7 @@ def apply_score_mod( batch_idx, head_idx, softmax_scale, - self.vec_size, + self.score_vec_size, self.qk_acc_dtype, aux_tensors, fastdiv_mods, diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index daa2e9c2d5c..47e1290fdd8 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -102,6 +102,27 @@ def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32: return x // (num_rep * num_wg) * num_rep + min(x % (num_rep * num_wg), num_rep) +@cute.jit +def apply_packed_mask_chunk( + X: cute.Tensor, + chunk_idx: cutlass.Constexpr[int], + mask: Uint32, +) -> None: + """Apply one 32-bit keep mask to one 32-column chunk. + + The one-iteration chunk loop keeps the same lowering pattern as mask_r2p_lambda. + """ + ncol = const_expr(cute.size(X.shape)) + col_base = chunk_idx * MASK_R2P_CHUNK_SIZE + for s in cutlass.range_constexpr(1): + for i in cutlass.range_constexpr( + min(MASK_R2P_CHUNK_SIZE, ncol - col_base - s * MASK_R2P_CHUNK_SIZE) + ): + in_bound = cutlass.Boolean(mask & (Uint32(1) << i)) + c = col_base + s * MASK_R2P_CHUNK_SIZE + i + X[c] = X[c] if in_bound else -Float32.inf + + @dataclass(frozen=True) class AttentionMask: tile_m: cutlass.Constexpr[int] @@ -369,6 +390,192 @@ def mask_gen_fn(s: int) -> Uint32: else acc_S_mn[r, c] ) + @cute.jit + def apply_mask_mod_sm100_scalar( + self, + acc_S: cute.Tensor, + tScS_t2r: cute.Tensor, + m_block: Int32, + n_block: Int32, + mask_seqlen: cutlass.Constexpr[bool], + mask_mod: cutlass.Constexpr[Callable], + batch_idx: Int32, + head_idx: Int32, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + head_divmod=None, + check_q_boundary: bool = False, + ) -> None: + """Apply a scalar FlexAttention mask_mod to an SM100 accumulator fragment. + + Each accumulator lane calls mask_mod once with logical (batch, head, q, kv) + indices. Pack-GQA rows are converted back to logical q/head indices before + the call. When aux tensors are present, indices are wrapped with fastdiv so + mask_mod never reads outside the per-example auxiliary storage. + """ + has_fastdiv = const_expr( + fastdiv_mods is not None and fastdiv_mods[0] is not None and fastdiv_mods[1] is not None + ) + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) + ncol = const_expr(cute.size(tScS_t2r.shape)) + + for i in cutlass.range_constexpr(ncol): + row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1] + col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0] + global_row = row_coord + m_block * self.tile_m + global_col = col_coord + n_block * self.tile_n + + if const_expr(self.qhead_per_kvhead_packgqa != 1): + assert head_divmod is not None + mask_row, head_offset = divmod(global_row, head_divmod) + head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset + else: + head_idx_for_mod = head_idx + mask_row = global_row + + mask_row_for_mod = mask_row + if const_expr(has_fastdiv and aux_tensors is not None): + if check_q_boundary: + _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0]) + global_col_for_mod = global_col + if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None): + _, global_col_for_mod = divmod(global_col, fastdiv_mods[1]) + + head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32) + mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32) + kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32) + mask_value = mask_mod( + batch_idx_ssa, + head_idx_ssa, + mask_row_ssa, + kv_idx_ssa, + self.seqlen_info, + aux_tensors, + ) + cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) + acc_S[i] = acc_S[i] if cond else -Float32.inf + if const_expr(mask_seqlen): + acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i] + if check_q_boundary: + acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i] + + @cute.jit + def apply_mask_mod_sm100_vector( + self, + acc_S: cute.Tensor, + tScS_t2r: cute.Tensor, + m_block: Int32, + n_block: Int32, + mask_seqlen: cutlass.Constexpr[bool], + mask_mod: cutlass.Constexpr[Callable], + batch_idx: Int32, + head_idx: Int32, + vec_size: cutlass.Constexpr[int], + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + head_divmod=None, + check_q_boundary: bool = False, + ) -> None: + """Apply a vectorized FlexAttention mask_mod to an SM100 fragment. + + mask_mod receives vec_size adjacent KV indices for one logical q row and + returns bit-packed Uint32 keep masks. Low bits correspond to lower KV + indices. The packed masks are combined with sequence-boundary checks, then + applied in 32-column chunks so the final masking lowers to R2P. + """ + has_fastdiv = const_expr( + fastdiv_mods is not None and fastdiv_mods[0] is not None and fastdiv_mods[1] is not None + ) + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) + ncol = const_expr(cute.size(tScS_t2r.shape)) + mask_vals_per_apply = const_expr(max(1, vec_size // 32)) + calls_per_apply = const_expr(max(1, 32 // vec_size)) + n_calls = const_expr(cute.ceil_div(ncol, vec_size)) + mask_vals = cute.make_rmem_tensor(mask_vals_per_apply, dtype=cutlass.Uint32) + + # Accumulate enough vector mask_mod calls to produce 32-bit chunks that + # apply_packed_mask_chunk can lower to R2P. + for s in cutlass.range_constexpr(n_calls): + if const_expr(s % calls_per_apply == 0): + for c in cutlass.range_constexpr(mask_vals_per_apply): + mask_vals[c] = cutlass.Uint32(0) + i = s * vec_size + row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1] + col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0] + global_row = row_coord + m_block * self.tile_m + global_col = col_coord + n_block * self.tile_n + if const_expr(self.qhead_per_kvhead_packgqa != 1): + assert head_divmod is not None + mask_row, head_offset = divmod(global_row, head_divmod) + head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset + else: + head_idx_for_mod = head_idx + mask_row = global_row + mask_row_for_mod = mask_row + if const_expr(has_fastdiv and aux_tensors is not None): + if check_q_boundary: + _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0]) + + head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32).broadcast_to( + (vec_size,) + ) + mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32).broadcast_to( + (vec_size,) + ) + batch_idx_ssa_call = batch_idx_ssa.broadcast_to((vec_size,)) + kv_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32) + + # Build the per-lane KV indices for this vectorized mask_mod call. + for j in cutlass.range_constexpr(min(vec_size, ncol - i)): + col_j_coord = tScS_t2r[i + j][1] if not self.swap_AB else tScS_t2r[i + j][0] + col_j_global = col_j_coord + n_block * self.tile_n + col_j_for_mod = col_j_global + if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None): + _, col_j_for_mod = divmod(col_j_global, fastdiv_mods[1]) + kv_idx_vec[j] = col_j_for_mod + kv_idx_ssa = kv_idx_vec.load() + + # mask_value is already bit-packed by the vectorized mask_mod. + mask_value = mask_mod( + batch_idx_ssa_call, + head_idx_ssa, + mask_row_ssa, + kv_idx_ssa, + self.seqlen_info, + aux_tensors, + ) + + # For vec_size < 32, multiple mask_mod calls fill one R2P chunk. + bit_offset = const_expr((s % calls_per_apply) * vec_size) + seqlen_thresh_call = ( + self.seqlen_k - global_col if const_expr(mask_seqlen) else cutlass.Int32(0) + ) + q_in_bounds = mask_row < self.seqlen_q if check_q_boundary else cutlass.Boolean(True) + for c in cutlass.range_constexpr(mask_vals_per_apply): + mask_val = mask_value[c] + if const_expr(vec_size < 32): + lane_keep = utils.shr_u32( + cutlass.Uint32(0xFFFFFFFF), + cutlass.Uint32(32 - vec_size), + ) + mask_val = mask_val & lane_keep + if const_expr(mask_seqlen): + mask_val = mask_val & r2p_bitmask_below(seqlen_thresh_call, c) + if check_q_boundary: + mask_val = mask_val if q_in_bounds else cutlass.Uint32(0) + mask_vals[c] = mask_vals[c] | (mask_val << bit_offset) + + # Apply only when the 32-bit chunk is complete, or at the tile tail. + is_last_in_apply = const_expr(s % calls_per_apply == calls_per_apply - 1) + is_last_overall = const_expr(s == n_calls - 1) + if const_expr(is_last_in_apply or is_last_overall): + apply_idx = s // calls_per_apply + for c in cutlass.range_constexpr(mask_vals_per_apply): + chunk_idx = apply_idx * mask_vals_per_apply + c + # Skip packed chunks that start past the accumulator fragment. + if const_expr(chunk_idx * 32 < ncol): + apply_packed_mask_chunk(acc_S, chunk_idx, mask_vals[c]) + @cute.jit def apply_mask_sm100( self, @@ -386,6 +593,7 @@ def apply_mask_sm100( aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), head_divmod=None, + vec_size: cutlass.Constexpr[int] = 1, check_q_boundary: bool = False, r2p: bool = True, rBitmask: Optional[cute.Tensor] = None, @@ -429,54 +637,43 @@ def apply_mask_sm100( ) elif const_expr(not mask_causal and not mask_local and mask_mod is not None): - # Block sparse case w/ mask_mod - has_fastdiv = const_expr( - fastdiv_mods is not None - and fastdiv_mods[0] is not None - and fastdiv_mods[1] is not None + # FlexAttention mask_mod vectorization is gated on `mask_mod.__vec_size__`. + # vec_size == 1 returns a scalar Boolean. vec_size > 1 returns packed + # Uint32 mask fragments: one word per 32 evaluated columns. + assert vec_size % 32 == 0 or 32 % vec_size == 0, ( + "vec_size must divide 32 or be a multiple of 32" ) - batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) - - ncol = const_expr(cute.size(tScS_t2r.shape)) - for i in cutlass.range_constexpr(ncol): - row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1] - col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0] - global_row = row_coord + m_block * self.tile_m - global_col = col_coord + n_block * self.tile_n - - if const_expr(self.qhead_per_kvhead_packgqa != 1): - assert head_divmod is not None - mask_row, head_offset = divmod(global_row, head_divmod) - head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset - else: - head_idx_for_mod = head_idx - mask_row = global_row - - mask_row_for_mod = mask_row - if const_expr(has_fastdiv and aux_tensors is not None): - if check_q_boundary: - _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0]) - global_col_for_mod = global_col - if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None): - _, global_col_for_mod = divmod(global_col, fastdiv_mods[1]) - - head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32) - mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32) - kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32) - mask_value = mask_mod( - batch_idx_ssa, - head_idx_ssa, - mask_row_ssa, - kv_idx_ssa, - self.seqlen_info, + if const_expr(vec_size == 1): + self.apply_mask_mod_sm100_scalar( + acc_S, + tScS_t2r, + m_block, + n_block, + mask_seqlen, + mask_mod, + batch_idx, + head_idx, aux_tensors, + fastdiv_mods, + head_divmod, + check_q_boundary, + ) + else: + self.apply_mask_mod_sm100_vector( + acc_S, + tScS_t2r, + m_block, + n_block, + mask_seqlen, + mask_mod, + batch_idx, + head_idx, + vec_size, + aux_tensors, + fastdiv_mods, + head_divmod, + check_q_boundary, ) - cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) - acc_S[i] = acc_S[i] if cond else -Float32.inf - if const_expr(mask_seqlen): - acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i] - if check_q_boundary: - acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i] else: # Causal or local causal_row_offset = self.seqlen_k - n_block * self.tile_n - self.seqlen_q diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index c8398c9a78d..3daffeeff18 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -144,7 +144,7 @@ def hash_callable( hasher = hashlib.sha256(base_hash.encode()) - for attr, val in zip(_MIXER_ATTRS, mixer_values): + for attr, val in zip(mixer_attrs, mixer_values): hasher.update(f"{attr}={val!r}".encode()) return hasher.hexdigest() diff --git a/tests/cute/mask_mod_definitions.py b/tests/cute/mask_mod_definitions.py index 38514f85e19..71cf0b9b7a5 100644 --- a/tests/cute/mask_mod_definitions.py +++ b/tests/cute/mask_mod_definitions.py @@ -496,6 +496,262 @@ def make_global_windows(seqlens_q, device="cuda"): return windows +# ============================================================================= +# Vectorized mask_mod variants (return bit-packed Uint32) +# ============================================================================= +# +# Each variant receives shape-(vec_size,) Int32 SSAs and returns a shape- +# (max(1, vec_size // 32),) Uint32 TensorSSA, where bit i of element k is the +# mask for lane (k * 32 + i). Bodies assume lane i has idx[i] = idx[0] + i, so +# they compute the packed Uint32(s) in O(1) via integer/bit arithmetic on the +# chunk-base indices instead of evaluating per lane. + + +@cute.jit +def cute_causal_mask_vec( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors: None, +) -> cute.TensorSSA: + offset = seqlen_info.seqlen_k - seqlen_info.seqlen_q + threshold = m_idx[0] + offset - n_idx[0] + cutlass.Int32(1) + m = max(cutlass.Int32(32) - threshold, cutlass.Int32(0)) + result = cute.make_rmem_tensor(1, dtype=cutlass.Uint32) + result[0] = utils.shr_u32(cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32(m)) + return result.load() + + +def get_cute_causal_mask_vec(offset: int): + return cute_causal_mask_vec + + +def get_cute_sliding_window_mask_vec(window_left: int, window_right: int, offset: int): + @cute.jit + def _cute_sliding_window_mask_vec( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, + ) -> cute.TensorSSA: + runtime_offset = seqlen_info.seqlen_k - seqlen_info.seqlen_q + center = m_idx[0] + runtime_offset + lo = center - cutlass.Int32(window_left) - n_idx[0] + hi_excl = center + cutlass.Int32(window_right) - n_idx[0] + cutlass.Int32(1) + m_below = max(cutlass.Int32(32) - hi_excl, cutlass.Int32(0)) + below = utils.shr_u32(cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32(m_below)) + n_above = max(lo, cutlass.Int32(0)) + above = utils.shl_u32(cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32(n_above)) + result = cute.make_rmem_tensor(1, dtype=cutlass.Uint32) + result[0] = below & above + return result.load() + + return _cute_sliding_window_mask_vec + + +@cute.jit +def cute_block_diagonal_mask_vec( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + block_size = cutlass.Int32(128) + block_m = m_idx[0] // block_size + lo = block_m * block_size - n_idx[0] + hi = (block_m + cutlass.Int32(1)) * block_size - n_idx[0] + m_below = max(cutlass.Int32(32) - hi, cutlass.Int32(0)) + below = utils.shr_u32(cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32(m_below)) + n_above = max(lo, cutlass.Int32(0)) + above = utils.shl_u32(cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32(n_above)) + result = cute.make_rmem_tensor(1, dtype=cutlass.Uint32) + result[0] = below & above + return result.load() + + +@cute.jit +def cute_prefix_lm_mask_vec( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + prefix = cutlass.Int32(512) + hi_pref = prefix - n_idx[0] + m_below_pref = max(cutlass.Int32(32) - hi_pref, cutlass.Int32(0)) + term1_below = utils.shr_u32(cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32(m_below_pref)) + row_in_prefix_mask = ( + cutlass.Uint32(0xFFFFFFFF) if m_idx[0] < prefix else cutlass.Uint32(0) + ) + term1 = term1_below & row_in_prefix_mask + hi_causal = m_idx[0] - n_idx[0] + cutlass.Int32(1) + m_below_causal = max(cutlass.Int32(32) - hi_causal, cutlass.Int32(0)) + term2 = utils.shr_u32(cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32(m_below_causal)) + result = cute.make_rmem_tensor(1, dtype=cutlass.Uint32) + result[0] = term1 | term2 + return result.load() + + +# ============================================================================= +# Packed-bitmask aux tensor mod +# ============================================================================= +# aux[0] is a (batch, max_seqlen_q, ceil(max_seqlen_k / 32)) Uint32 tensor where +# bit k of packed[b, q, c] is the mask for (b, q, c*32 + k). + + +@cute.jit +def cute_packed_mask_aux( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + packed = aux_tensors[0] + val = packed[batch[0], m_idx[0], n_idx[0] // cutlass.Int32(32)] + shift = cutlass.Uint32(n_idx[0] % cutlass.Int32(32)) + bit_set = cutlass.Boolean(utils.shr_u32(val, shift) & cutlass.Uint32(1)) + result = cute.make_rmem_tensor(n_idx.shape, dtype=cutlass.Boolean) + for j in cutlass.range_constexpr(cute.size(n_idx.shape)): + result[j] = bit_set + return result.load() + + +def get_cute_packed_mask_aux_vec(vec_size: int): + """Vec packed-mask, specialized for `vec_size`. For vec_size > 32, requires + `aux_tensors[0].__assumed_align__ = 16` and num_words divisible by 4.""" + if vec_size <= 32: + + @cute.jit + def _mod( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, + ) -> cute.TensorSSA: + packed = aux_tensors[0] + base = n_idx[0] // cutlass.Int32(32) + val = packed[batch[0], m_idx[0], base] + shift = cutlass.Uint32(n_idx[0] % cutlass.Int32(32)) + result = cute.make_rmem_tensor(1, dtype=cutlass.Uint32) + result[0] = utils.shr_u32(val, shift) + return result.load() + else: + num_words = vec_size // 32 + + @cute.jit + def _mod( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, + ) -> cute.TensorSSA: + packed = aux_tensors[0] + b_str, m_str, _ = packed.stride + packed_aligned = cute.make_tensor( + packed.iterator, + cute.make_layout( + packed.shape, + stride=( + cute.assume(b_str, divby=4), + cute.assume(m_str, divby=4), + 1, + ), + ), + ) + packed_row = packed_aligned[batch[0], m_idx[0], None] + packed_tiled = cute.flat_divide(packed_row, (num_words,)) + base = n_idx[0] // cutlass.Int32(32) + packed_chunk = packed_tiled[None, base // cutlass.Int32(num_words)] + loaded = cute.make_rmem_tensor_like(packed_chunk) + cute.autovec_copy(packed_chunk, loaded) + result = cute.make_rmem_tensor(num_words, dtype=cutlass.Uint32) + for k in cutlass.range_constexpr(num_words): + result[k] = cutlass.Uint32(loaded[k]) + return result.load() + + return _mod + + +def make_packed_mask_aux_tensor( + batch: int, + seqlen_q: int, + seqlen_k: int, + density: float = 0.5, + device="cuda", + seed: int = 0, +): + """Random Uint32 bit-packed mask. num_words is rounded up to a multiple of 4 + so each row is 16-byte aligned (LDG.E.128 requirement at vec_size=128).""" + g = torch.Generator(device=device).manual_seed(seed) + num_words = ((seqlen_k + 31) // 32 + 3) // 4 * 4 + bools = ( + torch.rand(batch, seqlen_q, num_words * 32, device=device, generator=g) + < density + ) + bools = bools.reshape(batch, seqlen_q, num_words, 32) + powers = 1 << torch.arange(32, device=device, dtype=torch.int64) + packed = (bools.to(torch.int64) * powers).sum(-1).to(torch.uint32) + packed.__assumed_align__ = 16 + return packed + + +VEC_MASK_FACTORIES = { + "causal": ("factory", get_cute_causal_mask_vec), + "block_causal": ("factory", get_cute_causal_mask_vec), + "sliding_window": ("factory_window", get_cute_sliding_window_mask_vec), + "block_diagonal": ("static", cute_block_diagonal_mask_vec), + "prefix_lm": ("static", cute_prefix_lm_mask_vec), + "packed_aux": ("factory_vec_size", get_cute_packed_mask_aux_vec), +} + + +# Scalar mods for vec masks not in STATIC_MASKS / PARAMETERIZED_MASK_FACTORIES. +EXTRA_SCALAR_MASKS = { + "packed_aux": cute_packed_mask_aux, +} + + +def get_vec_mask( + mask_name, seqlen_q=None, seqlen_k=None, window_size=None, vec_size=None +): + """Return a vectorized cute mask callable for `mask_name`, or None if there + is no vec form. Caller sets `__vec_size__` on the returned callable. + `vec_size` is required for masks whose body specializes on it (packed_aux).""" + if mask_name not in VEC_MASK_FACTORIES: + return None + kind, obj = VEC_MASK_FACTORIES[mask_name] + if kind == "static": + return obj + if kind == "factory_vec_size": + if vec_size is None: + raise ValueError(f"{mask_name} vec mask requires vec_size") + return obj(vec_size) + offset = ( + (seqlen_k - seqlen_q) if (seqlen_q is not None and seqlen_k is not None) else 0 + ) + if kind == "factory": + return obj(offset) + if kind == "factory_window": + if window_size is None: + raise ValueError("sliding_window vec mask requires window_size") + return obj(window_size, window_size, offset) + raise ValueError(f"unknown vec mask kind: {kind}") + + # ============================================================================= # Mask registry and factory functions # ============================================================================= diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 52f78e8d26d..a4228dc48a0 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -32,7 +32,13 @@ ) from flash_attn.cute.cache_utils import get_jit_cache from flash_attn.cute import utils -from mask_mod_definitions import get_mask_pair, random_doc_id_tensor +from mask_mod_definitions import ( + get_mask_pair, + get_vec_mask, + random_doc_id_tensor, + EXTRA_SCALAR_MASKS, + make_packed_mask_aux_tensor, +) COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] @@ -47,6 +53,7 @@ def reset_torch_state(): torch._dynamo.reset() torch.cuda.empty_cache() + def create_tensors( batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype ): @@ -828,6 +835,122 @@ def test_parameterized_masks( ) +# ============================================================================= +# Vectorized mask_mod equality tests +# Pattern: scalar mask is reference; vec mask at multiple __vec_size__ values +# must produce bit-identical output. +# ============================================================================= + +# (mask_name, window_size, needs_aux) +VEC_MASK_TEST_CASES = [ + ("causal", None, False), + ("block_causal", None, False), + ("sliding_window", 128, False), + ("block_diagonal", None, False), + ("prefix_lm", None, False), + ("packed_aux", None, True), +] + +# Vectorized mask_mod application is currently implemented for SM100/SM110 forward. +# vec_size > 32 is only supported by packed_aux (other vec mods return shape-(1,) Uint32). +VEC_MASK_SIZES_TO_CHECK_EQUALITY = [2, 8, 32, 128] + + +def _run_mask_mod_only(q, k, v, mask_mod, aux_tensors, pack_gqa): + out = torch.empty_like(q) + _flash_attn_fwd( + q=q, + k=k, + v=v, + out=out, + lse=None, + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=1.0 / math.sqrt(q.shape[-1]), + causal=False, + softcap=None, + window_size_left=-1, + window_size_right=-1, + learnable_sink=None, + tile_mn=(128, 128), + pack_gqa=pack_gqa, + _arch=None, + score_mod=None, + mask_mod=mask_mod, + block_sparse_tensors=None, + return_lse=False, + aux_tensors=aux_tensors, + ) + return out + + +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [(128, 128), (256, 256), (113, 203), (256, 512), (1024, 1024)], +) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 4), (4, 2)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("mask_case", VEC_MASK_TEST_CASES) +def test_cute_mask_mod_vectorized( + seqlen_q, seqlen_k, qhead_per_kvhead, num_kv_heads, dtype, mask_case +): + """Tests equality between scalar and vectorized versions of mask mods.""" + if COMPUTE_CAPABILITY not in (10, 11): + pytest.skip("vectorized mask_mod application is SM100/SM110-only") + mask_name, window_size, needs_aux = mask_case + torch.manual_seed(42) + + num_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + head_dim = 128 + batch_size = 2 + + q = torch.randn( + batch_size, seqlen_q, num_heads, head_dim, device="cuda", dtype=dtype + ) + k = torch.randn( + batch_size, seqlen_k, num_kv_heads, head_dim, device="cuda", dtype=dtype + ) + v = torch.randn( + batch_size, seqlen_k, num_kv_heads, head_dim, device="cuda", dtype=dtype + ) + + if needs_aux: + aux_tensors = [make_packed_mask_aux_tensor(batch_size, seqlen_q, seqlen_k)] + scalar_mod = EXTRA_SCALAR_MASKS[mask_name] + else: + aux_tensors = None + scalar_mod, _ = get_mask_pair( + mask_name, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + window_size=window_size, + ) + + out_ref = _run_mask_mod_only(q, k, v, scalar_mod, aux_tensors, pack_gqa) + + for vec_size in VEC_MASK_SIZES_TO_CHECK_EQUALITY: + if vec_size > 32 and mask_name != "packed_aux": + continue + vec_mod = get_vec_mask( + mask_name, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + window_size=window_size, + vec_size=vec_size, + ) + if vec_mod is None: + pytest.skip(f"no vec mask for {mask_name}") + vec_mod.__vec_size__ = vec_size + out = _run_mask_mod_only(q, k, v, vec_mod, aux_tensors, pack_gqa) + assert torch.equal(out, out_ref), ( + f"{mask_name} vec_size={vec_size}: output mismatch vs scalar reference" + ) + + def test_sm100_block_sparse_sink_all_masked(): """Block-sparse regression for the sink path""" if torch.cuda.get_device_capability()[0] != 10: diff --git a/tests/cute/test_mask_mod_varlen.py b/tests/cute/test_mask_mod_varlen.py index 70d08763cb3..6e37e9ed4b8 100644 --- a/tests/cute/test_mask_mod_varlen.py +++ b/tests/cute/test_mask_mod_varlen.py @@ -23,9 +23,12 @@ from flash_attn.cute.compute_block_sparsity import compute_block_sparsity from mask_mod_definitions import ( get_mask_pair, + get_vec_mask, random_doc_id_tensor, STATIC_MASKS, PARAMETERIZED_MASK_FACTORIES, + EXTRA_SCALAR_MASKS, + make_packed_mask_aux_tensor, cute_global_packed_doc_mask, cute_global_ima_mask, cute_global_causal_window_mask, @@ -613,6 +616,142 @@ def test_varlen_global_masks(seqlens_q, seqlens_k, mask_name): ) +# ============================================================================= +# Vectorized mask_mod equality tests (varlen) +# Pattern: scalar mask is reference; vec mask at multiple __vec_size__ values +# must produce bit-identical output. +# ============================================================================= + +# (mask_name, window_size, needs_aux) +VEC_MASK_TEST_CASES = [ + ("causal", None, False), + ("block_causal", None, False), + ("sliding_window", 128, False), + ("block_diagonal", None, False), + ("prefix_lm", None, False), + ("packed_aux", None, True), +] + +# Vectorized mask_mod application is currently implemented for SM100/SM110 forward. +# vec_size > 32 is only supported by packed_aux (other vec mods return shape-(1,) Uint32). +VEC_MASK_SIZES_TO_CHECK_EQUALITY = [2, 8, 32, 128] + + +def _run_varlen_mask_only( + q, k, v, cu_seqlens_q, cu_seqlens_k, mask_mod, aux_tensors, pack_gqa +): + out = torch.empty_like(q) + _flash_attn_fwd( + q=q, + k=k, + v=v, + out=out, + lse=None, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=1.0 / math.sqrt(q.shape[-1]), + causal=False, + softcap=None, + window_size_left=-1, + window_size_right=-1, + learnable_sink=None, + tile_mn=(128, 128), + pack_gqa=pack_gqa, + _arch=None, + score_mod=None, + mask_mod=mask_mod, + block_sparse_tensors=None, + return_lse=False, + aux_tensors=aux_tensors, + ) + return out + + +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS_SMOKE) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("kv_mode", ["mha", "gqa"]) +@pytest.mark.parametrize("mask_case", VEC_MASK_TEST_CASES) +def test_varlen_mask_mod_vectorized(seqlens_q, seqlens_k, dtype, kv_mode, mask_case): + """Tests equality between scalar and vectorized mask mods on varlen inputs.""" + if COMPUTE_CAPABILITY not in (10, 11): + pytest.skip("vectorized mask_mod application is SM100/SM110-only") + mask_name, window_size, needs_aux = mask_case + + if mask_name == "block_causal": + offsets = [sk - sq for sq, sk in zip(seqlens_q, seqlens_k)] + if len(set(offsets)) > 1: + pytest.skip( + "block_causal captures offset as compile-time constant; " + "varlen with different per-sequence offsets not supported" + ) + if mask_name == "sliding_window": + for sq, sk in zip(seqlens_q, seqlens_k): + if sq > sk: + pytest.skip( + "sliding_window requires seqlen_q <= seqlen_k for each sequence" + ) + + torch.manual_seed(42) + num_heads = 8 + if kv_mode == "gqa": + if COMPUTE_CAPABILITY < 9: + pytest.xfail("pack_gqa requires SM90+") + num_kv_heads = 2 + else: + num_kv_heads = num_heads + pack_gqa = num_heads != num_kv_heads + head_dim = 128 + + q, k, v, cu_seqlens_q, cu_seqlens_k = setup_varlen_tensors( + seqlens_q, seqlens_k, num_heads, num_kv_heads, head_dim, dtype + ) + + batch_size = len(seqlens_q) + max_seqlen_q = max(seqlens_q) + max_seqlen_k = max(seqlens_k) + + if needs_aux: + aux_tensors = [ + make_packed_mask_aux_tensor(batch_size, max_seqlen_q, max_seqlen_k) + ] + scalar_mod = EXTRA_SCALAR_MASKS[mask_name] + else: + aux_tensors = None + scalar_mod, _ = get_mask_pair( + mask_name, + seqlen_q=max_seqlen_q, + seqlen_k=max_seqlen_k, + window_size=window_size, + ) + + out_ref = _run_varlen_mask_only( + q, k, v, cu_seqlens_q, cu_seqlens_k, scalar_mod, aux_tensors, pack_gqa + ) + + for vec_size in VEC_MASK_SIZES_TO_CHECK_EQUALITY: + if vec_size > 32 and mask_name != "packed_aux": + continue + vec_mod = get_vec_mask( + mask_name, + seqlen_q=max_seqlen_q, + seqlen_k=max_seqlen_k, + window_size=window_size, + vec_size=vec_size, + ) + if vec_mod is None: + pytest.skip(f"no vec mask for {mask_name}") + vec_mod.__vec_size__ = vec_size + out = _run_varlen_mask_only( + q, k, v, cu_seqlens_q, cu_seqlens_k, vec_mod, aux_tensors, pack_gqa + ) + assert torch.equal(out, out_ref), ( + f"{mask_name} vec_size={vec_size}: output mismatch vs scalar reference" + ) + + # ============================================================================= # Block sparsity end-to-end tests # ============================================================================= diff --git a/tests/cute/test_utils.py b/tests/cute/test_utils.py index 189eb86957d..4ec077933cc 100644 --- a/tests/cute/test_utils.py +++ b/tests/cute/test_utils.py @@ -160,6 +160,19 @@ def tracking_sha256(*args, **kwargs): assert call_tracker["sha256"] == 0, "sha256 should not be called" assert result == "wrapped-fast-hash" + def test_vec_size_affects_hash(self): + def mask_mod(_b, _h, q_idx, kv_idx): + return q_idx >= kv_idx + + base_hash = hash_callable(mask_mod, set_cute_hash=False) + mask_mod.__vec_size__ = 16 + vec16_hash = hash_callable(mask_mod, set_cute_hash=False) + mask_mod.__vec_size__ = 32 + vec32_hash = hash_callable(mask_mod, set_cute_hash=False) + + assert base_hash != vec16_hash + assert vec16_hash != vec32_hash + def test_closure_values_affect_hash(self): """Functions with different closure values should have different hashes.""" value1 = 10 From 2d5d5a1c95a34884b3708ef66e93db482c3a5902 Mon Sep 17 00:00:00 2001 From: Junrong Lin <33685709+ocss884@users.noreply.github.com> Date: Mon, 25 May 2026 06:55:10 +0800 Subject: [PATCH 16/21] Update architecture assertion for SM 10.x and 11.x (#2572) --- flash_attn/cute/flash_fwd_sm100.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index cd5dd1eb7c9..576238bcafb 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -159,7 +159,8 @@ def __init__( assert self.split_P_arrive % 32 == 0 assert self.split_P_arrive < self.n_block_size self.arch = BaseDSL._get_dsl().get_arch_enum() - assert self.arch >= Arch.sm_100 and self.arch <= Arch.sm_110f, "Only SM 10.x and 11.x are supported" + assert self.arch.is_family_of(Arch.sm_100f) or self.arch.is_family_of(Arch.sm_110f), \ + "Only SM 10.x and 11.x are supported" self.cta_group_size = 2 if self.use_2cta_instrs else 1 # cta_tiler M includes only 1 CTA, the scheduler will take into account the cluster shape From 59cf5378123f26fad50c2f12b97cec74558feb64 Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Mon, 25 May 2026 23:26:17 -0700 Subject: [PATCH 17/21] Include sm_110 in Blackwell-family arch gating (follow-up to #2572) (#2590) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix bwd postprocess 2CTA gating to include sm_11x The 2CTA gating in flash_bwd_postprocess.py used `arch // 10 == 10`, which only matches SM 10.x (B100/B200/B300) and misses SM 11.x (Thor). The rest of the codebase (e.g. interface.py:549, 563, 834) consistently gates Blackwell-family 2CTA features as `arch // 10 in [10, 11]`. Bring the two postprocess sites in line with that convention. Flagged by @jayhshah in #2572 follow-up discussion. * Include sm_110 in interface.py Blackwell-family heuristics Three sites in interface.py gate Blackwell-family behavior using `arch // 10 == 10`, which appears inconsistent with the rest of the file's `arch // 10 in [10, 11]` convention (used at lines 549, 563, 834, 974, 1035, etc.): - L533: `q_stage` heuristic for Blackwell forward - L579: `use_dedicated_hd256_kernel` (forward) - L1335: `use_dedicated_hd256_kernel` (backward) The dispatch in `_flash_attn_fwd` already routes both sm_10x and sm_11x through the same `FlashAttentionForwardSm100` / MLA classes, so these gates likely should treat them the same. NOTE FOR REVIEWERS: I'm not certain these are all oversight vs. intentional SM100-only paths. If any of them is intentional, please flag so I can revert just that hunk. The FP8 assert at L480 is left untouched on purpose — its error message reads as deliberate. * Apply ruff format to flash_bwd_sm100.py Pre-existing format drift surfaced by pre-commit. Not in the cute_exclude pattern, so it gets auto-fixed when other files in flash_attn/cute/ are touched in the same commit chain. --- flash_attn/cute/flash_bwd_postprocess.py | 4 ++-- flash_attn/cute/flash_bwd_sm100.py | 4 +--- flash_attn/cute/interface.py | 8 ++++---- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 76c856221c5..94f0c88d817 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -63,7 +63,7 @@ def __init__( self.num_threads = num_threads self.AtomLayoutMdQ = AtomLayoutMdQ self.dQ_swapAB = dQ_swapAB - self.use_2cta_instrs = use_2cta_instrs and arch // 10 == 10 and head_dim != 64 + self.use_2cta_instrs = use_2cta_instrs and arch // 10 in [10, 11] and head_dim != 64 self.cluster_size = cluster_size @staticmethod @@ -373,7 +373,7 @@ def kernel( seqlen_q = seqlen.seqlen_q seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) - if const_expr(self.arch // 10 == 10 and self.use_2cta_instrs): + if const_expr(self.arch // 10 in [10, 11] and self.use_2cta_instrs): # 2-CTA: remap dQaccum layout into TMEM view before writing sdQ num_reduce_threads = self.num_threads thr_mma_dsk = tiled_mma.get_slice(tidx) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 81462e50afd..061ede3d983 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -1423,9 +1423,7 @@ def kernel( ) TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) - AttentionMaskCls = self._generate_attention_mask_cls( - window_size_left, window_size_right - ) + AttentionMaskCls = self._generate_attention_mask_cls(window_size_left, window_size_right) # EMPTY # (15) if warp_idx == self.empty_warp_id: diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 189ae1faca7..5e8674bf1ad 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -530,7 +530,7 @@ def _flash_attn_fwd( if cu_seqlens_k is None and seqused_k is None: min_seqlen_k = seqlen_k seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead - if arch // 10 == 10: + if arch // 10 in [10, 11]: q_stage = 2 if seqlen_q_packgqa > tile_m else 1 else: q_stage = 1 @@ -575,8 +575,8 @@ def _flash_attn_fwd( and (tile_m % qhead_per_kvhead == 0 or not pack_gqa) ) - # hd=256 2CTA forward uses dedicated kernel (SM100 only) - use_dedicated_hd256_kernel = arch // 10 == 10 and head_dim == 256 and head_dim_v == 256 + # hd=256 2CTA forward uses dedicated kernel (Blackwell family) + use_dedicated_hd256_kernel = arch // 10 in [10, 11] and head_dim == 256 and head_dim_v == 256 use_2cta_instrs = use_2cta_instrs or use_dedicated_hd256_kernel if softcap is not None: @@ -1332,7 +1332,7 @@ def _flash_attn_bwd( cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1 use_2cta_instrs = cluster_size==2 - use_dedicated_hd256_kernel = arch // 10 == 10 and head_dim == 256 and head_dim_v == 256 + use_dedicated_hd256_kernel = arch // 10 in [10, 11] and head_dim == 256 and head_dim_v == 256 use_2cta_instrs = use_2cta_instrs or use_dedicated_hd256_kernel q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ From 6c4f74fb338e0c3cdb07ac6f5eab5f54fc367c15 Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Mon, 25 May 2026 23:50:22 -0700 Subject: [PATCH 18/21] Use is_family_of for sm_90 and sm_103 arch checks (#2589) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Use is_family_of for sm_90 and sm_103 arch checks Follow-up to #2572 — apply the same is_family_of pattern to the two remaining range-style arch checks for consistency: - flash_fwd_sm90.py:69 (SM 9.x assert) - flash_fwd_sm100.py:195 (is_sm103 flag) Same semantic narrowing as #2572: bare-base SMs (sm_90, sm_103) are excluded. These kernels rely on wgmma / UMMA / 2CTA paths that require the a/f PTX variant anyway, so bare-base targets could not compile. * Clarify is_sm103 forward-inclusive semantics is_family_of(sm_103f) also matches any future sm_10x with x > 3, not just sm_103a/f. This was raised in PR review (@ocss884) — adding an inline comment clarifying that this forward-inclusive behavior is intentional: the flag gates ex2 emulation, sm_103 (B300) has fast hardware ex2, and later Blackwell variants in the same family are assumed to inherit it. No code-behavior change. --- flash_attn/cute/flash_fwd_sm100.py | 6 +++++- flash_attn/cute/flash_fwd_sm90.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 576238bcafb..57755d12cb9 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -192,7 +192,11 @@ def __init__( self.mask_vec_size: cutlass.Constexpr = getattr(mask_mod, "__vec_size__", 1) # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) - is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f + # NOTE: is_family_of also matches any future sm_10x with x > 3 — intentional. + # The flag gates ex2 emulation; sm_103 (B300) has fast hardware ex2 and later + # Blackwell variants are assumed to inherit this, so forward-inclusion is correct + # despite the literal `is_sm103` name. + is_sm103 = self.arch.is_family_of(Arch.sm_103f) self.is_sm103 = is_sm103 # enable_ex2_emu is derived: True if tuning config has freq > 0, else fallback to default logic _default_enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103 diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py index 3d57d6718fc..93bccfa715b 100644 --- a/flash_attn/cute/flash_fwd_sm90.py +++ b/flash_attn/cute/flash_fwd_sm90.py @@ -66,7 +66,7 @@ def __init__( "Paged KV does not support irregular head dim" ) self.cluster_shape_mn = (1, 1) - assert self.arch >= Arch.sm_90 and self.arch <= Arch.sm_90a, "Only SM 9.x is supported" + assert self.arch.is_family_of(Arch.sm_90a), "Only SM 9.x is supported" def _get_smem_layout_atom(self): sQ_layout_atom = warpgroup.make_smem_layout_atom( From 59f01d6e1a1655a148ed4b22b5d4fbb9da2c2cf0 Mon Sep 17 00:00:00 2001 From: Strahinja Stamenkovic Date: Wed, 27 May 2026 16:25:52 +0200 Subject: [PATCH 19/21] Bump AITER submodule to commit 3b2e6f4 (#2540) * Bump aiter submodule commit Co-authored-by: sstamenk <170634954+sstamenk@users.noreply.github.com> * Bump aiter submodule to 3b2e6f48ce97e1d494e8b3f1af5c65f74e304b28 (#2) Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: sstamenk <170634954+sstamenk@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: sstamenk <170634954+sstamenk@users.noreply.github.com> --- third_party/aiter | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/aiter b/third_party/aiter index b4b75165fbd..3b2e6f48ce9 160000 --- a/third_party/aiter +++ b/third_party/aiter @@ -1 +1 @@ -Subproject commit b4b75165fbd2456dfd0f074c5b2ef91bc87d97e5 +Subproject commit 3b2e6f48ce97e1d494e8b3f1af5c65f74e304b28 From 0bbb25a3a5ad3c58c029b3d287d6c9af56a5cad5 Mon Sep 17 00:00:00 2001 From: Johnsonms Date: Wed, 27 May 2026 21:00:08 -0700 Subject: [PATCH 20/21] Clamp kv_stage to avoid SMEM overflow for small head_dims on SM100 (#2594) * Clamp kv_stage to avoid SMEM overflow for small head_dims on SM100 Fixes #2591. The unbounded formula at flash_fwd_sm100.py:335 ignores per-stage state (mbarriers, sScale, pipeline counters) and yields kv_stage values that overflow the sm_100a 227 KB SMEM cap when head_dim_padded=16 (head_dim in {8, ..., 16}). Repro: hd=8/16 + seqlen >= 256 + bf16 fails with cudaErrorInvalidValue ("launch shared memory exceeds current GPU arch sm_100a allowed. Allocated: 233472 bytes. Max: 232448 bytes."). Clamp kv_stage at 32. Surgical to the broken case: the unbounded formula maxes at 26 stages for head_dim_padded >= 32, and the 2CTA gate at interface.py:572 restricts 2CTA to hd_padded in {128, 192} (both no-op), so the clamp only fires at hd_padded in {8, 16}. Verified across 24 configs (hd in {8,16,32,64,96,128} x causal in {T,F} x seqlen in {128,2048}) on B200 with max_err vs torch SDPA <= 0.0078. * Add test_flash_attn_small_head_dim regression test The main test_flash_attn_output parametrizes d over {64, 96, 128, 192, 256} and never exercises head_dim < 64, even though _validate_head_dims accepts head_dim >= 8 for sm_100/110. That coverage gap let the SMEM-overflow bug in #2591 slip through. This focused test covers d in {8, 16, 32} x causal x seqlen in {128, 2048}. The seqlen=2048 cases push q_stage 1->2 (the actual bug trigger); the seqlen=128 cases also exercise the q_stage=1 boundary that fits on main today but is structurally adjacent. d=32 serves as a canary against any future tighter kv_stage clamp regressing it. --- flash_attn/cute/flash_fwd_sm100.py | 5 +++- tests/cute/test_flash_attn.py | 44 ++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 57755d12cb9..82638f341cd 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -336,7 +336,10 @@ def _setup_attributes(self): smem_size_k_per_stage = self.n_block_size * self.head_dim_padded * self.k_dtype.width // 8 smem_size_v_per_stage = self.n_block_size * self.head_dim_v_padded * self.v_dtype.width // 8 smem_size_kv_per_stage = max(smem_size_k_per_stage, smem_size_v_per_stage) // self.cta_group_size - kv_stage = (224 * 1024 - smem_size_q_o) // smem_size_kv_per_stage + # Cap small head_dim from over-staging: the 224*1024 budget undercounts + # per-stage state, so at hd_padded=16 the unbounded formula picks 52 stages + # and overflows the 227 KB SMEM cap. No-op for hd_padded >= 32 (max 26). + kv_stage = min((224 * 1024 - smem_size_q_o) // smem_size_kv_per_stage, 32) if self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and kv_stage == 2: # For hdim 192,128, we can fit 3 stages if we use uneven_kv_smem kv_stage = 3 diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 764d7123681..bf881efe1c0 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -438,6 +438,50 @@ def test_flash_attn_output( ).abs().max().item() + dv_atol +# Regression test for #2591: SMEM overflow at small head_dims on SM100. The main +# test_flash_attn_output skips d < 64, but _validate_head_dims accepts head_dim >= 8 +# for sm_100/110, so this path needs coverage. Trigger requires +# seqlen_q_packgqa > tile_m to push q_stage 1->2. +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [8, 16, 32]) +@pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128), (2048, 2048)]) +@retry_on_oom +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_small_head_dim(seqlen_q, seqlen_k, d, causal, dtype): + device = "cuda" + seed = 0 + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.empty_cache() + torch.cuda.synchronize() + batch_size = 2 + nheads = 2 + nheads_kv = nheads + dtype_ref = dtype + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ).requires_grad_() + k_ref = torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ).requires_grad_() + v_ref = torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ).requires_grad_() + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + out_ref, _ = attention_ref(q_ref, k_ref, v_ref, None, None, causal=causal) + out_pt, _ = attention_ref( + q_ref, k_ref, v_ref, None, None, causal=causal, upcast=False, reorder_ops=True + ) + out, _ = flash_attn_func(q, k, v, causal=causal) + if is_fake_mode(): + return + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + assert (out - out_ref).abs().max().item() <= 2 * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + + # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) From 316617c5759673d70a19a95219f84e56fbf33654 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Sat, 30 May 2026 13:42:28 -0500 Subject: [PATCH 21/21] Fix pre-commit Signed-off-by: Matthew Bonanni --- flash_attn/cute/flash_fwd_combine.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py index 75bf77c89b4..84cf365d46a 100644 --- a/flash_attn/cute/flash_fwd_combine.py +++ b/flash_attn/cute/flash_fwd_combine.py @@ -69,7 +69,10 @@ def can_implement( ) -> bool: """Check if the kernel can be implemented with the given parameters.""" if dtype not in [ - cutlass.Float16, cutlass.BFloat16, cutlass.Float32, cutlass.Float8E4M3FN, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float32, + cutlass.Float8E4M3FN, ]: return False if dtype_partial not in [cutlass.Float16, cutlass.BFloat16, Float32]: