Wrap mask construction in a function for mask subclassing #2584

Merged
drisspg merged 1 commit into Dao-AILab:main from sryap:mask-subclass
May 22, 2026
@sryap sryap commented May 22, 2026

Summary:

Extract the inline `AttentionMask` construction in `FlashAttentionForwardSm100` and `FlashAttentionBackwardSm100` into an overridable `_generate_attention_mask_cls` method. This lets subclasses inject a custom `AttentionMask` without modifying the base kernel code.
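The shape of this refactor can be sketched in plain Python without any GPU dependencies. The sketch below is only an illustration: `ToyMask`, `ToyKernel`, `build_mask`, and the `seqlen_q` / `seqlen_k` arguments are hypothetical stand-ins, and the assumption that the remaining constructor arguments are supplied later at the call site is mine, not something this PR states.

```python
from dataclasses import dataclass
from functools import partial
from typing import Optional


@dataclass
class ToyMask:
    """Hypothetical stand-in for AttentionMask; the fields are illustrative."""
    tile_m: int
    tile_n: int
    seqlen_q: int
    seqlen_k: int
    window_size_left: Optional[int] = None
    window_size_right: Optional[int] = None


class ToyKernel:
    """Hypothetical stand-in for a kernel class that used to build its mask inline."""

    def __init__(self, m_block_size: int, n_block_size: int):
        self.m_block_size = m_block_size
        self.n_block_size = n_block_size

    def _generate_attention_mask_cls(self, window_size_left, window_size_right):
        # The hook: return a partially applied constructor.
        # Subclasses override this method to swap in their own mask class.
        return partial(
            ToyMask,
            self.m_block_size,
            self.n_block_size,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
        )

    def build_mask(self, seqlen_q, seqlen_k, window_size_left=None, window_size_right=None):
        # The kernel body only talks to the hook, never to ToyMask directly.
        mask_cls = self._generate_attention_mask_cls(window_size_left, window_size_right)
        return mask_cls(seqlen_q, seqlen_k)


@dataclass
class ScaledToyMask(ToyMask):
    row_divisor: int = 1  # extra field that exists only on the subclass


class ScaledToyKernel(ToyKernel):
    def __init__(self, *args, row_divisor=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.row_divisor = row_divisor

    def _generate_attention_mask_cls(self, window_size_left, window_size_right):
        return partial(
            ScaledToyMask,
            self.m_block_size,
            self.n_block_size,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            row_divisor=self.row_divisor,
        )
```

Because `build_mask` goes through the hook, `ScaledToyKernel(128, 128, row_divisor=2).build_mask(512, 512)` returns a `ScaledToyMask` without any change to `ToyKernel` itself.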

For example, a custom attention kernel can override the mask to add a `causal_q_divisor` field that scales the `row_idx` value.

```
from functools import partial

# cutlass, cute, AttentionMask, and FlashAttentionForwardSm100 are assumed to be
# imported from the surrounding codebase.


class CustomAttentionMask(AttentionMask):
    # Extra compile-time field; the default of 1 leaves row_idx unchanged.
    causal_q_divisor: cutlass.Constexpr[int] = 1

    @cute.jit
    def apply_mask_sm100(self, acc_S, m_block, n_block, ...):
        # Custom causal logic using causal_q_divisor
        row_idx = (tScS_t2r[0][0] + m_block * self.tile_m) // self.causal_q_divisor
        ...


class CustomFlashAttentionForwardSm100(FlashAttentionForwardSm100):
    def __init__(self, *args, causal_q_divisor=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.causal_q_divisor = causal_q_divisor

    def _generate_attention_mask_cls(self, window_size_left, window_size_right):
        # Same partial as the base class, but bound to the custom mask
        # and carrying the extra field.
        return partial(
            CustomAttentionMask,
            self.m_block_size,
            self.n_block_size,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            bottom_right=self.is_bottom_right,
            causal_q_divisor=self.causal_q_divisor,
        )
```
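To make the effect of the divisor concrete, here is a small pure-Python sketch of a causal rule applied to a scaled row index. It is a simplification under stated assumptions: the helpers `causal_visible` and `mask_rows` are hypothetical, the rule is top-left aligned, and tiling, sliding windows, and bottom-right alignment are all ignored. The real `apply_mask_sm100` works on accumulator fragments and is not reproduced here.

```python
def causal_visible(row: int, col: int, divisor: int = 1) -> bool:
    """Illustrative causal rule: a key column is visible if it does not
    exceed the query row index after integer division by `divisor`."""
    return col <= row // divisor


def mask_rows(num_rows: int, num_cols: int, divisor: int = 1) -> list[str]:
    """Render the mask as strings: 'x' marks a visible entry, '.' a masked one."""
    return [
        "".join("x" if causal_visible(r, c, divisor) else "." for c in range(num_cols))
        for r in range(num_rows)
    ]


for line in mask_rows(4, 3, divisor=2):
    print(line)
# x..
# x..
# xx.
# xx.
```

With `divisor=2`, each pair of consecutive rows shares one causal boundary, whereas `divisor=1` gives the usual staircase (`x..`, `xx.`, `xxx`, `xxx`). One plausible use, which is my assumption rather than something the PR states, is a layout where several physical query rows correspond to the same logical position.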

Test Plan:

```
$ pytest tests/cute/test_flash_attn_fast.py -v

================ 240 passed, 4139 warnings in 984.24s (0:16:24) ================
```

Reviewers:

Subscribers:

Tasks:

Tags:


@drisspg drisspg left a comment

LGTM

@drisspg drisspg merged commit 0cb66b4 into Dao-AILab:main May 22, 2026
@sryap sryap deleted the mask-subclass branch May 22, 2026 22:49
reubenconducts pushed a commit to reubenconducts/flash-attention that referenced this pull request Jun 2, 2026