Block Sparsity and Flex Attention mask mod support #1942
Merged
jayhshah merged 17 commits into Dao-AILab:main from reubenconducts:rstern/flex-mask-mod on Oct 21, 2025
Commits (17), showing changes from all commits:
f6880af  clean up and rebase for PR (reubenconducts)
e7121a3  add mask mod tests (reubenconducts)
864de11  add benchmarking files (reubenconducts)
d300e33  refactor for better style (reubenconducts)
9fbc2d4  remove extraneous csrc (reubenconducts)
b81eaa4  type hint buffers (reubenconducts)
e05ec82  Merge remote-tracking branch 'upstream/main' into rstern/flex-mask-mo… (reubenconducts)
5d5bb09  refactor: order of non/overlap and modify blocksparse producer to agr… (reubenconducts)
a17bb58  change variable name back to buffers (reubenconducts)
7c563ac  remove unnecessary variable in first_half_block (reubenconducts)
b5f7082  restore erroneous packgqa deletion (reubenconducts)
ab5c024  add blocksparsity and mask_mod asserts to interface.py (reubenconducts)
06820e8  fix rebase issues (reubenconducts)
db0ea95  Restore submodule and reset pointer to upstream/main (reubenconducts)
41ba160  rename cutlass.const_expr to const_expr (reubenconducts)
c6e0d6b  support fully masked m blocks (i.e. skipped tiles) (reubenconducts)
d28e6a8  remove outdated commented code (reubenconducts)
@@ -1,5 +1,6 @@
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0.
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0.

# Supported features:
# - BF16 & FP16 dtype
@@ -73,7 +74,12 @@ def _flash_attn_fwd(
    num_threads: int = 384,
    pack_gqa: Optional[bool] = None,
    _compute_capability: Optional[int] = None,
    score_mod: Callable | None = None,
    score_mod: Optional[Callable] = None,
    mask_mod: Optional[Callable] = None,
    full_block_cnt: Optional[torch.Tensor] = None,
    full_block_idx: Optional[torch.Tensor] = None,
    mask_block_cnt: Optional[torch.Tensor] = None,
    mask_block_idx: Optional[torch.Tensor] = None,
    return_lse: bool = False,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
@@ -135,7 +141,22 @@ def _flash_attn_fwd(
    if learnable_sink is not None:
        assert learnable_sink.shape == (num_head,)
        assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"
    assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, page_table, learnable_sink)), "inputs must be on CUDA device"
    for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]:
        if t is not None:
            assert t.dtype == torch.int32, "blocksparse mask tensors must be int32"
            assert t.stride(0) == 1, "blocksparse mask tensors must be contiguous"
    assert all(
        t is None or t.is_cuda
        for t in (
            q, k, v,
            cu_seqlens_q, cu_seqlens_k,
            seqused_q, seqused_k,
            page_table,
            learnable_sink,
            full_block_cnt, full_block_idx,
            mask_block_cnt, mask_block_idx,
        )
    ), "inputs must be on CUDA device"
    assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
    assert head_dim <= 256, "head_dim must be less than or equal to 256"
    alignment = 16 // q.element_size()
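A note on the four new tensors: their shapes are not spelled out in this excerpt, but the later mark_layout_dynamic(leading_dim=2) / (leading_dim=3) calls suggest count tensors of shape (batch, heads, q_blocks) and index tensors of shape (batch, heads, q_blocks, kv_blocks), in the spirit of FlexAttention's BlockMask: "full" KV blocks need no per-element masking, "mask" blocks get mask_mod applied inside the kernel, and blocks listed in neither are skipped. Below is a rough, unoptimized sketch of how such tensors could be built from a boolean predicate; all names, shapes, and block sizes here are assumptions for illustration, not part of this PR.

import torch

def sliding_window_causal(b, h, q_idx, kv_idx, window=256):
    # Example FlexAttention-style predicate: True means "keep this score".
    return (kv_idx <= q_idx) & (q_idx - kv_idx < window)

def build_block_sparse_tensors(mask_mod, batch, heads, seqlen_q, seqlen_k,
                               block_m=128, block_n=128, device="cuda"):
    # Classify each (q_block, kv_block) tile per batch/head as fully unmasked
    # ("full", no in-kernel mask_mod needed), partially masked ("mask"), or
    # entirely masked (skipped: listed in neither tensor). Slow reference loop,
    # meant only to show the intended contents of the four tensors.
    q_idx = torch.arange(seqlen_q, device=device)
    kv_idx = torch.arange(seqlen_k, device=device)
    n_qblk = (seqlen_q + block_m - 1) // block_m
    n_kblk = (seqlen_k + block_n - 1) // block_n
    full_cnt = torch.zeros(batch, heads, n_qblk, dtype=torch.int32, device=device)
    mask_cnt = torch.zeros(batch, heads, n_qblk, dtype=torch.int32, device=device)
    full_idx = torch.zeros(batch, heads, n_qblk, n_kblk, dtype=torch.int32, device=device)
    mask_idx = torch.zeros(batch, heads, n_qblk, n_kblk, dtype=torch.int32, device=device)
    for b in range(batch):
        for h in range(heads):
            dense = mask_mod(b, h, q_idx[:, None], kv_idx[None, :])  # (seqlen_q, seqlen_k) bool
            for qb in range(n_qblk):
                rows = dense[qb * block_m:(qb + 1) * block_m]
                for kb in range(n_kblk):
                    tile = rows[:, kb * block_n:(kb + 1) * block_n]
                    if tile.all():      # fully unmasked tile
                        full_idx[b, h, qb, full_cnt[b, h, qb]] = kb
                        full_cnt[b, h, qb] += 1
                    elif tile.any():    # partially masked tile
                        mask_idx[b, h, qb, mask_cnt[b, h, qb]] = kb
                        mask_cnt[b, h, qb] += 1
    return full_cnt, full_idx, mask_cnt, mask_idx

The int32 dtype and contiguity of the outputs match the asserts in the hunk above.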
@@ -183,6 +204,13 @@ def _flash_attn_fwd(
        for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)
    ]
    page_table_tensor = from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) if page_table is not None else None

    full_block_cnt_tensor = from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) if full_block_cnt is not None else None
    full_block_idx_tensor = from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) if full_block_idx is not None else None
    mask_block_cnt_tensor = from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) if mask_block_cnt is not None else None
    mask_block_idx_tensor = from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=3) if mask_block_idx is not None else None

    if causal:
        window_size_right = 0
    local = window_size_left is not None or window_size_right is not None
@@ -202,22 +230,44 @@ def _flash_attn_fwd(
    # TODO: fix the varlen case
    if pack_gqa and (128 % qhead_per_kvhead != 0) or (cu_seqlens_q is not None or seqused_q is not None):
        pack_gqa = False

    # hash score and mask mods for compile cache
    score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else None
    mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else None

    if softcap is not None:
        assert score_mod is None, "softcap and score_mod cannot be used together"
        score_mod = utils.create_softcap_scoremod(softcap)

    is_varlen = cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None
    use_block_sparsity = full_block_cnt is not None or mask_block_cnt is not None
    if score_mod is not None:
        is_varlen = cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None
        if is_varlen:
            raise NotImplementedError("score_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR.")
        if pack_gqa:
            raise NotImplementedError("score_mod with buffers is not yet supported with pack_gqa=True. This will be fixed in a future PR.")

Collaborator comment on the score_mod checks above: I think this is a rebase bug, I added support here: #1937

    if mask_mod is not None:
        if not use_block_sparsity:
            raise NotImplementedError("mask_mod requires the use of block sparsity. This will be fixed in a future PR.")
        if is_varlen:
            raise NotImplementedError("mask_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR.")
        if pack_gqa:
            raise NotImplementedError("mask_mod with buffers is not yet supported with pack_gqa=True. This will be fixed in a future PR.")

    if use_block_sparsity:
        if is_varlen:
            raise NotImplementedError("Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR.")
        if pack_gqa:
            raise NotImplementedError("Block sparsity is not yet supported with pack_gqa=True. This will be fixed in a future PR.")
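One note on the softcap branch above: rather than a dedicated kernel path, soft-capping is expressed through the score_mod machinery via utils.create_softcap_scoremod. A minimal sketch of the transformation such a score mod applies, written with plain torch ops purely for illustration (the actual helper presumably emits Cute-DSL code and may take different arguments):

import torch

def make_softcap_score_mod(softcap: float):
    # Standard attention logit soft-capping: squash scores into
    # (-softcap, softcap) with a scaled tanh before the softmax.
    def score_mod(score, *args):
        return softcap * torch.tanh(score / softcap)
    return score_mod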
    cute_buffers = None
    if buffers is not None:
        cute_buffers = [from_dlpack(buf) for buf in buffers]

    compile_key = (
        dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, utils.hash_callable(score_mod) if score_mod is not None else None,
        dtype, head_dim, head_dim_v, qhead_per_kvhead, causal,
        score_mod_hash, mask_mod_hash,
        buffers is not None,
        lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None,
        page_table is not None,
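The compile key now carries score_mod_hash and mask_mod_hash so that two different Python callables never share a compiled kernel. The implementation of utils.hash_callable is not shown in this excerpt; a plausible stand-in, purely illustrative and not this repo's code, would hash the callable's source or bytecode:

import hashlib
import inspect

def _hash_callable_sketch(fn):
    # Key on the function's source if available, falling back to raw bytecode,
    # so that editing a score_mod/mask_mod invalidates the cached kernel.
    try:
        payload = inspect.getsource(fn)
    except (OSError, TypeError):
        payload = fn.__code__.co_code.hex()
    return hashlib.sha256(payload.encode()).hexdigest()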
@@ -245,6 +295,9 @@ def _flash_attn_fwd(
        num_stages=2,
        num_threads=num_threads,
        Q_in_regs=False,
        intra_wg_overlap=True,
        mma_pv_is_rs=True,
        mask_mod=mask_mod,
        score_mod=score_mod,
        has_buffers=buffers is not None,
    )
@@ -264,18 +317,21 @@ def _flash_attn_fwd(
    else:
        raise ValueError(f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x")
    # TODO: check @can_implement
    # TODO caching for buffers; cute_buffers
    _flash_attn_fwd.compile_cache[compile_key] = cute.compile(
        fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream,
        cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor,
        page_table_tensor,
        window_size_left, window_size_right, learnable_sink_tensor, cute_buffers,
        window_size_left, window_size_right, learnable_sink_tensor,
        full_block_cnt_tensor, full_block_idx_tensor, mask_block_cnt_tensor, mask_block_idx_tensor,
        cute_buffers,
    )
    _flash_attn_fwd.compile_cache[compile_key](
        q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream,
        cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor,
        page_table_tensor,
        window_size_left, window_size_right, learnable_sink_tensor, cute_buffers
        window_size_left, window_size_right, learnable_sink_tensor,
        full_block_cnt_tensor, full_block_idx_tensor, mask_block_cnt_tensor, mask_block_idx_tensor,
        cute_buffers,
    )
    return out, lse
@@ -591,6 +647,11 @@ def forward(
        learnable_sink: Optional[torch.Tensor] = None,
        softcap: float = 0.0,
        pack_gqa: Optional[bool] = None,
        mask_mod: Optional[Callable] = None,
        full_block_cnt: Optional[torch.Tensor] = None,
        full_block_idx: Optional[torch.Tensor] = None,
        mask_block_cnt: Optional[torch.Tensor] = None,
        mask_block_idx: Optional[torch.Tensor] = None,
    ):
        out, lse = _flash_attn_fwd(
            q,
@@ -603,6 +664,11 @@ def forward(
            learnable_sink=learnable_sink,
            softcap=softcap,
            pack_gqa=pack_gqa,
            mask_mod=mask_mod,
            full_block_cnt=full_block_cnt,
            full_block_idx=full_block_idx,
            mask_block_cnt=mask_block_cnt,
            mask_block_idx=mask_block_idx,
        )
        ctx.save_for_backward(q, k, v, out, lse)
        ctx.softmax_scale = softmax_scale
@@ -706,6 +772,11 @@ def flash_attn_func(
    learnable_sink: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
    pack_gqa: Optional[bool] = None,
    mask_mod: Optional[Callable] = None,
    full_block_cnt: Optional[torch.Tensor] = None,
    full_block_idx: Optional[torch.Tensor] = None,
    mask_block_cnt: Optional[torch.Tensor] = None,
    mask_block_idx: Optional[torch.Tensor] = None,
):
    return FlashAttnFunc.apply(
        q,
@@ -717,6 +788,11 @@ def flash_attn_func(
        learnable_sink,
        softcap,
        pack_gqa,
        mask_mod,
        full_block_cnt,
        full_block_idx,
        mask_block_cnt,
        mask_block_idx,
    )
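Putting the new surface area together, a call might look like the sketch below. The keyword names match the diff; everything else (tensor layout, the import path, the helper and predicate from the earlier sketch, whether mask_mod is expressed over plain indices or Cute-DSL values, and the return convention) is assumed rather than taken from this PR.

import torch
from flash_attn.cute.interface import flash_attn_func  # import path assumed

batch, heads, seqlen, head_dim = 2, 8, 4096, 128
q = torch.randn(batch, seqlen, heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seqlen, heads, head_dim, device="cuda", dtype=torch.bfloat16)

# Block-sparse metadata from the hypothetical helper sketched earlier.
full_cnt, full_idx, mask_cnt, mask_idx = build_block_sparse_tensors(
    sliding_window_causal, batch, heads, seqlen, seqlen
)

out = flash_attn_func(  # exact return signature is not shown in this excerpt
    q, k, v,
    mask_mod=sliding_window_causal,  # applied only inside partially masked blocks
    full_block_cnt=full_cnt,
    full_block_idx=full_idx,
    mask_block_cnt=mask_cnt,
    mask_block_idx=mask_idx,
)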
@@ -973,4 +1049,4 @@ def flash_attn_combine(
        lse = None

    _flash_attn_fwd_combine(out_partial, lse_partial, out, lse)
    return out, lse
    return out, lse
Some post-facto comments, feel free to ignore: can we maybe put the full and mask cnts/indices in a tuple so that it's easier to pass around?
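One lightweight way to do that without changing the kernel-facing arguments would be a small container that callers build once and the wrapper unpacks; the names below are hypothetical, not part of this PR:

from typing import NamedTuple, Optional
import torch

class BlockSparseTensors(NamedTuple):
    # Bundles the four block-sparsity tensors so they travel together.
    full_block_cnt: Optional[torch.Tensor] = None
    full_block_idx: Optional[torch.Tensor] = None
    mask_block_cnt: Optional[torch.Tensor] = None
    mask_block_idx: Optional[torch.Tensor] = None

# Caller side:
#   blocks = BlockSparseTensors(full_cnt, full_idx, mask_cnt, mask_idx)
#   flash_attn_func(q, k, v, mask_mod=mask_mod, **blocks._asdict())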