Skip to content

[Cute,Flex,Sm100] vectorized mask_mod#2261

Merged
drisspg merged 6 commits into
Dao-AILab:mainfrom
reubenconducts:rstern/mask-vec
May 24, 2026
Merged

[Cute,Flex,Sm100] vectorized mask_mod#2261
drisspg merged 6 commits into
Dao-AILab:mainfrom
reubenconducts:rstern/mask-vec

Conversation

@reubenconducts
Copy link
Copy Markdown
Contributor

@reubenconducts reubenconducts commented Feb 17, 2026

Follow-up to #2236. The approach to vectorizing is bipartite:

  • Vectorize mask application, to compile down to r2p
  • Vectorize mask evaluation

The latter is important for example in situations where mask_mod depends on aux_tensors that are contiguous in the kv idx, or when aux_tensors don't depend on kv index at all.

mask_mods still emit TensorSSAs, but they need not be single values. These are treated as bit-packed masks.

Note: this current work is Sm100 only.

See mask_mod_definitions.py for many examples. Vectorization leads to a speedup in all relevant mask mods:

  ┌───────────────────────────────────┬────────────┬────────────────────┬────────────────────┬─────────────────────┬──────────────┐
  │               mask                │   (b, s)   │ scalar ms / TFLOPS │ vec_32 ms / TFLOPS │ vec_128 ms / TFLOPS │ best speedup │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ causal                            │ (8, 2048)  │ 0.107 / 644        │ 0.100 / 690        │ —                   │ +7%          │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ causal                            │ (4, 4096)  │ 0.159 / 862        │ 0.153 / 898        │ —                   │ +4%          │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ causal                            │ (2, 8192)  │ 0.266 / 1035       │ 0.262 / 1049       │ —                   │ +1%          │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ causal                            │ (1, 16384) │ 0.485 / 1135       │ 0.481 / 1143       │ —                   │ +1%          │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ block_causal                      │ (8, 2048)  │ 0.106 / 648        │ 0.100 / 691        │ —                   │ +6%          │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ block_causal                      │ (1, 16384) │ 0.484 / 1136       │ 0.481 / 1144       │ —                   │ +1%          │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ sliding_window 256                │ (8, 2048)  │ 0.121 / 267        │ 0.077 / 419        │ —                   │ +57%         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ sliding_window 256                │ (4, 4096)  │ 0.125 / 267        │ 0.079 / 421        │ —                   │ +58%         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ sliding_window 256                │ (2, 8192)  │ 0.127 / 268        │ 0.080 / 425        │ —                   │ +59%         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ sliding_window 256                │ (1, 16384) │ 0.128 / 268        │ 0.080 / 425        │ —                   │ +60%         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ sliding_with_sink sink=64 win=256 │ (8, 2048)  │ 0.108 / 333        │ 0.082 / 438        │ —                   │ +32%         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ sliding_with_sink sink=64 win=256 │ (4, 4096)  │ 0.108 / 344        │ 0.085 / 438        │ —                   │ +27%         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ sliding_with_sink sink=64 win=256 │ (2, 8192)  │ 0.109 / 348        │ 0.087 / 439        │ —                   │ +26%         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ sliding_with_sink sink=64 win=256 │ (1, 16384) │ 0.111 / 345        │ 0.088 / 438        │ —                   │ +26%         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ block_diagonal 128                │ (8, 2048)  │ 0.072 / 120        │ 0.052 / 164        │ —                   │ +38%         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ block_diagonal 128                │ (4, 4096)  │ 0.073 / 118        │ 0.053 / 161        │ —                   │ +38%         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ block_diagonal 128                │ (2, 8192)  │ 0.072 / 119        │ 0.053 / 163        │ —                   │ +36%         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ block_diagonal 128                │ (1, 16384) │ 0.073 / 118        │ 0.053 / 162        │ —                   │ +38%         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ prefix_lm 512                     │ (8, 2048)  │ 0.113 / 646        │ 0.100 / 729        │ —                   │ +13%         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ prefix_lm 512                     │ (4, 4096)  │ 0.167 / 836        │ 0.155 / 904        │ —                   │ +8%          │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ prefix_lm 512                     │ (1, 16384) │ 0.492 / 1120       │ 0.482 / 1142       │ —                   │ +2%          │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ document                          │ (8, 2048)  │ 0.229 / 74         │ 0.118 / 143        │ 0.109 / 155         │ 2.1×         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ document                          │ (4, 4096)  │ 0.266 / 166        │ 0.152 / 291        │ 0.137 / 323         │ 1.9×         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ document                          │ (2, 8192)  │ 0.368 / 206        │ 0.247 / 307        │ 0.216 / 350         │ 1.7×         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ document                          │ (1, 16384) │ 0.264 / 139        │ 0.150 / 245        │ 0.135 / 273         │ 2.0×         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ packed_aux                        │ (8, 2048)  │ 1.425 / 48         │ 0.212 / 324        │ 0.160 / 430         │ 8.9×         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ packed_aux                        │ (4, 4096)  │ 2.780 / 49         │ 0.393 / 350        │ 0.284 / 483         │ 9.8×         │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ packed_aux                        │ (2, 8192)  │ 5.490 / 50         │ 0.755 / 364        │ 0.535 / 513         │ 10.3×        │
  ├───────────────────────────────────┼────────────┼────────────────────┼────────────────────┼─────────────────────┼──────────────┤
  │ packed_aux                        │ (1, 16384) │ 10.937 / 50        │ 1.479 / 372        │ 1.056 / 521         │ 10.4×        │
  └───────────────────────────────────┴────────────┴────────────────────┴────────────────────┴─────────────────────┴──────────────┘

I added tests checking bitwise equality between ordinary and vectorized paths; those and existing mask mod tests all pass (test_mask_mod and test_mask_mod_varlen, respectively):
Screenshot 2026-05-11 at 9 58 58 PM
Screenshot 2026-05-11 at 10 13 33 PM

cc @drisspg @v0i0

Comment thread flash_attn/cute/mask.py Outdated
Copy link
Copy Markdown
Collaborator

@drisspg drisspg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

couple clarifying questions but this looks good, I just put up autotuning PR: pytorch/pytorch#176055

helps alot in some cases

@reubenconducts reubenconducts changed the title [WIP,Cute,Flex,Sm100] vectorized mask mod application [Cute,Flex,Sm100] vectorized mask_mod May 12, 2026
@reubenconducts reubenconducts marked this pull request as ready for review May 12, 2026 02:21
@drisspg drisspg merged commit fe5fb1b into Dao-AILab:main May 24, 2026
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jun 1, 2026
# Summary

broke the top PR up into 2 to try and make this more managable.

In this pr we basically do 1 main thing. In a recent pr we enabled the ability to invoke mask-mod functions in a vectorized form. Dao-AILab/flash-attention#2261. This allows us to use r2p instructions for masking and opens up the possibility for doing vectorized loads from aux tensors.

A big bulk of this code is doing code motion and more typing to get the analysis tools in their own file: torch/_inductor/kernel/flex/aux_vectorization.py

This is mostly sympy shenangins and math to see if we are in 1 of three cases.
1. we are lane uniform: You are loading kv 0,1,2,3... and your expresions only depdnes on q which we know from how we added the vecotorization hook the q value will be consistent across the full span.
2. you are conitugous up to some vec width. lets say you start with positive indices and your expression is `KV%4`. Since kv starts at 0 (for this expression) we know that you can can do vector operations with vecwidth of 4
3. every lane is loading from non contiguous locations and needs to be gather. `kv*2`.

That is basicallly it. I did heaps of fuzz testing I feel confident that we are conservative where we need to be.  We also needed to update the codegen in flex/flash for loads to be able to render the right autovec copy. I think the main point where review is needed is in this one file since this is the bulk of new behavior. Perhaps if you are more familiar with sympy expression analysis and have better ways to structure things there that would be helpful. And maybe the codegen stuff but I feel good there

Pull Request resolved: #185020
Approved by: https://github.com/eellison
reubenconducts added a commit to reubenconducts/flash-attention that referenced this pull request Jun 2, 2026
* vectorized mask mod application for existing mask mod signatures

* add vectorized mask mod examples, get vectorized evaluation and application working

* guard sm80/90/120 against mask_vec_size > 2

* thread mask_vec_size thru sm80/90/120 kernel

* Small tweaks coverign sm90

* Small tweaks coverign sm90

---------

Co-authored-by: drisspg <drisspguessous@gmail.com>
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jun 2, 2026
This is alot of code huh.. welcome to 2026 ;).

# Summary

Why driss why. Let me tell you. Thanks to Rebuen we now have a nice vectorized maskmod support once we land: Dao-AILab/flash-attention#2261 and our codgen needs to update in order to take advantage of this!

I have started landing pieces of this: #183406 for isntance which has huge wins for loading from tensors that are contiguous along the kv dim. This lets us acutally do wider vectorized loads and greatly improves perf.

We want to do the same for mask mod loads. And that is 1 of the two things this PR does. From benchmarking I found that while we do see 128/256 bit width loads when expected and we can get ~ 20% perf bump we are still slower than reproted in that PR. And the main gap is that for contiguous spans `PackedMaskAnalyzer` in this pr. basically for things like causal prefixlm/ sliding window where the masking a contiguous span  with lo and high bounds we can do much more optimized code using R2P and baking predication into UINT32

Non zero amount of code for  q-uniform matters :
For packed mask lowering, bounds may depend on batch, head, query index, or captured aux tensors, but they must be uniform across the vectorized KV lanes. I.e.

They can depend on things that are the same for the whole 32-lane group:

 - batch index b
 - head index h
 - query index q_idx
 - the group start kv_idx
 - captured tensor loads indexed only by q-uniform values, e.g. offsets[doc_id[b, q_idx]]

 But they cannot depend on lane in an arbitrary way.

We (mostly codex) did alot of fuzz testing to determine a sane way to filter out patterns that dont fit mostly by reusing utils form our vec load utils where feasible.

This is so much code, and I hate it, but it works. I'm still trying to trim it down/make things more grokkable for me. Yes, I have read every line but still. I have not read every line of the tests, and this has mostly been FuzzTest-driven. I launched 25 subagents to fuzz test this code, and it added these unit tests. I'm going to go through and see if I can simplify them.

## Perf

| shape      | case              |   sparsity |   old us |   auto us |   speedup |   auto TFLOPs | mask vec   |   autovec | shift/R2P   | lane pack   |
|------------|-------------------|------------|----------|-----------|-----------|---------------|------------|-----------|-------------|-------------|
| llama8b_8k | causal            |      0.484 |   460.91 |    461.6  |     0.999 |        1233   | True       |         0 | True        | False       |
| llama8b_8k | document_ids      |      0.949 |   225.2  |    166.54 |     1.352 |         336.6 | True       |         1 | False       | True        |
| llama8b_8k | document_offsets  |      0.895 |   317.18 |    157.22 |     2.018 |         737.1 | True       |         0 | True        | False       |
| llama8b_8k | gather_only       |      0.613 |  1921.7  |   1924.16 |     0.999 |         221.8 | False      |         0 | False       | False       |
| llama8b_8k | mixed_gather_tail |      0     | 12324.1  |   3485.06 |     3.536 |         316.7 | True       |         2 | False       | True        |
| llama8b_8k | prefix_lm         |      0.457 |   490.93 |    455.2  |     1.078 |        1316.6 | True       |         0 | True        | False       |
| llama8b_8k | qkv_bias          |      0     |  6310.77 |   2123.01 |     2.973 |         519.9 | True       |         1 | False       | True        |
| llama8b_8k | rank1_kv          |      0.484 |  1358.93 |   1360.82 |     0.999 |         418.2 | False      |         0 | False       | False       |
| llama8b_8k | sliding_window    |      0.938 |   163.94 |     92.34 |     1.775 |         735.5 | True       |         0 | True        | False       |

### Codegen updates for those curious

## qkv_bias old scalar

```python
def mask_mod(_b, _h, q_idx, kv_idx):
            return (q_idx >= kv_idx) | (bias[q_idx, kv_idx] > 0)
```

```python
    @cute.jit
    def mask_mod(b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):

        in_ptr8 = aux_tensors[0]

        tmp0 = q_idx
        tmp1 = kv_idx
        tmp2 = operator.ge(tmp0, tmp1)
        tmp3 = ssa_to_fragment(q_idx, cutlass.Int32)
        tmp4 = ssa_to_fragment(kv_idx, cutlass.Int32)
        tmp5 = cute.make_rmem_tensor(cute.size(tmp3.shape), cutlass.BFloat16)
        for load_idx in cutlass.range(cute.size(tmp5.shape), unroll_full=True):
            tmp5[load_idx] = (in_ptr8[tmp3[load_idx], tmp4[load_idx]])
        tmp6 = (tmp5.load()).to(cutlass.Float32)
        tmp7 = operator.gt(tmp6, cute.full_like(tmp6, 0.0))
        tmp8 = (tmp2 | tmp7)
        mask_mod_output = tmp8

        return mask_mod_output

    mask_mod.__cute_hash__ = "cutedsl_fused_flex_attention_15852c9b_mask"
```

## qkv_bias auto vec32

```python
    @cute.jit
    def mask_mod(b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):

        in_ptr8 = aux_tensors[0]

        tmp0 = q_idx
        tmp1 = kv_idx
        tmp2 = operator.ge(tmp0, tmp1)
        tmp3 = ssa_to_fragment(q_idx, cutlass.Int32)
        tmp4 = ssa_to_fragment(kv_idx, cutlass.Int32)
        tmp5 = cute.make_rmem_tensor(cute.size(tmp3.shape), cutlass.BFloat16)

        tmp6 = cute.assume(tmp4[0], divby=cute.size(tmp5.shape))
        tmp7 = cute.local_tile(in_ptr8, (1, cute.size(tmp5.shape)), (tmp3[0], tmp6 // cute.size(tmp5.shape)))
        tmp8 = cute.make_ptr(cutlass.BFloat16, tmp7.iterator.toint(), tmp7.iterator.memspace, assumed_align=min(16, cute.size(tmp5.shape) * 2))
        tmp9 = cute.make_tensor(tmp8, tmp7.layout)
        cute.autovec_copy(tmp9[0, None], tmp5)
        tmp10 = (tmp5.load()).to(cutlass.Float32)
        tmp11 = operator.gt(tmp10, cute.full_like(tmp10, 0.0))
        tmp12 = (tmp2 | tmp11)
        mask_mod_output = tmp12

        mask_mod_packed = cute.make_rmem_tensor(1, dtype=cutlass.Uint32)
        mask_mod_packed[0] = cutlass.Uint32(0)
        for mask_lane_idx in cutlass.range_constexpr(32):
            mask_bit = cutlass.Uint32(1) << mask_lane_idx
            mask_mod_packed[0] = (
                mask_mod_packed[0] | mask_bit
                if cutlass.Boolean(mask_mod_output[mask_lane_idx])
                else mask_mod_packed[0]
            )
        return mask_mod_packed.load()

    mask_mod.__cute_hash__ = "cutedsl_fused_flex_attention_78e3f466_mask"

    mask_mod.__mask_vec_size__ = 32
    mask_mod.__vec_size__ = 32
```

## sliding_window old scalar

```python
    @cute.jit
    def mask_mod(b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):

        tmp0 = q_idx
        tmp1 = kv_idx
        tmp2 = operator.ge(tmp0, tmp1)
        tmp3 = (tmp0 - tmp1)
        tmp4 = operator.le(tmp3, cute.full_like(tmp3, 256))
        tmp5 = (tmp2 & tmp4)
        mask_mod_output = tmp5

        return mask_mod_output

    mask_mod.__cute_hash__ = "cutedsl_fused_flex_attention_d776a5fa_mask"
```

## sliding_window auto packed interval

```python
    @cute.jit
    def mask_mod(b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):

        mask_mod_packed = cute.make_rmem_tensor(1, dtype=cutlass.Uint32)
        mask_mod_packed[0] = cutlass.Uint32(0)

        interval_lower_0 = max(cutlass.Int32(0), (cutlass.Int32(-256) + q_idx[0] + (cutlass.Int32(-1) * kv_idx[0])))
        interval_upper_0 = min((cutlass.Int32(1) + q_idx[0] + (cutlass.Int32(-1) * kv_idx[0])), cutlass.Int32(32))
        below_0 = utils.shr_u32(
            cutlass.Uint32(0xFFFFFFFF),
            cutlass.Uint32(
                min(
                    max(
                        cutlass.Int32(32) - interval_upper_0,
                        cutlass.Int32(0),
                    ),
                    cutlass.Int32(32),
                )
            ),
        )
        above_0 = utils.shl_u32(
            cutlass.Uint32(0xFFFFFFFF),
            cutlass.Uint32(
                min(
                    max(interval_lower_0, cutlass.Int32(0)),
                    cutlass.Int32(32),
                )
            ),
        )
        mask_mod_packed[0] = mask_mod_packed[0] | (
            below_0 & above_0
        )

        return mask_mod_packed.load()

    mask_mod.__cute_hash__ = "cutedsl_fused_flex_attention_7648c49a_mask"

    mask_mod.__mask_vec_size__ = 32
    mask_mod.__vec_size__ = 32
```

Pull Request resolved: #184438
Approved by: https://github.com/eellison
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants