Wrap mask contruction in a function for mask subclassing#2584
Merged
Conversation
Summary:
Extract the inline `AttentionMask` construction in
`FlashAttentionForwardSm100` and `FlashAttentionBackwardSm100` into an
overridable `_generate_attention_mask_cls` method. This allows
subclasses to inject a custom `AttentionMask` without modifying the
base kernel code.
For example, a custom attention kernel can override the mask to add a
`causal_q_divisor` field for scaling the `row_idx` value.
```
class CustomAttentionMask(AttentionMask):
causal_q_divisor: cutlass.Constexpr[int] = 1
@cute.jit
def apply_mask_sm100(self, acc_S, m_block, n_block, ...):
# Custom causal logic using causal_q_divisor
row_idx = (tScS_t2r[0][0] + m_block * self.tile_m) // self.causal_q_divisor
...
class CustomFlashAttentionForwardSm100(FlashAttentionForwardSm100):
def __init__(self, *args, causal_q_divisor=1, **kwargs):
super().__init__(*args, **kwargs)
self.causal_q_divisor = causal_q_divisor
def _generate_attention_mask_cls(self, window_size_left, window_size_right):
return partial(
CustomAttentionMask,
self.m_block_size,
self.n_block_size,
window_size_left=window_size_left,
window_size_right=window_size_right,
bottom_right=self.is_bottom_right,
causal_q_divisor=self.causal_q_divisor,
)
```
Test Plan:
```
$ pytest tests/cute/test_flash_attn_fast.py -v
================ 240 passed, 4139 warnings in 984.24s (0:16:24) ================
```
Reviewers:
Subscribers:
Tasks:
Tags:
reubenconducts
pushed a commit
to reubenconducts/flash-attention
that referenced
this pull request
Jun 2, 2026
) Summary: Extract the inline `AttentionMask` construction in `FlashAttentionForwardSm100` and `FlashAttentionBackwardSm100` into an overridable `_generate_attention_mask_cls` method. This allows subclasses to inject a custom `AttentionMask` without modifying the base kernel code. For example, a custom attention kernel can override the mask to add a `causal_q_divisor` field for scaling the `row_idx` value. ``` class CustomAttentionMask(AttentionMask): causal_q_divisor: cutlass.Constexpr[int] = 1 @cute.jit def apply_mask_sm100(self, acc_S, m_block, n_block, ...): # Custom causal logic using causal_q_divisor row_idx = (tScS_t2r[0][0] + m_block * self.tile_m) // self.causal_q_divisor ... class CustomFlashAttentionForwardSm100(FlashAttentionForwardSm100): def __init__(self, *args, causal_q_divisor=1, **kwargs): super().__init__(*args, **kwargs) self.causal_q_divisor = causal_q_divisor def _generate_attention_mask_cls(self, window_size_left, window_size_right): return partial( CustomAttentionMask, self.m_block_size, self.n_block_size, window_size_left=window_size_left, window_size_right=window_size_right, bottom_right=self.is_bottom_right, causal_q_divisor=self.causal_q_divisor, ) ``` Test Plan: ``` $ pytest tests/cute/test_flash_attn_fast.py -v ================ 240 passed, 4139 warnings in 984.24s (0:16:24) ================ ``` Reviewers: Subscribers: Tasks: Tags:
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary:
Extract the inline
AttentionMaskconstruction inFlashAttentionForwardSm100andFlashAttentionBackwardSm100into an overridable_generate_attention_mask_clsmethod. This allows subclasses to inject a customAttentionMaskwithout modifying the base kernel code.For example, a custom attention kernel can override the mask to add a
causal_q_divisorfield for scaling therow_idxvalue.Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags: