[Cute,Flex,Sm100] vectorized mask_mod#2261
Merged
Merged
Conversation
drisspg
reviewed
Feb 28, 2026
drisspg
reviewed
Feb 28, 2026
Collaborator
drisspg
left a comment
There was a problem hiding this comment.
couple clarifying questions but this looks good, I just put up autotuning PR: pytorch/pytorch#176055
helps alot in some cases
ca1affb to
c8614b2
Compare
drisspg
approved these changes
May 23, 2026
pytorchmergebot
pushed a commit
to pytorch/pytorch
that referenced
this pull request
Jun 1, 2026
# Summary broke the top PR up into 2 to try and make this more managable. In this pr we basically do 1 main thing. In a recent pr we enabled the ability to invoke mask-mod functions in a vectorized form. Dao-AILab/flash-attention#2261. This allows us to use r2p instructions for masking and opens up the possibility for doing vectorized loads from aux tensors. A big bulk of this code is doing code motion and more typing to get the analysis tools in their own file: torch/_inductor/kernel/flex/aux_vectorization.py This is mostly sympy shenangins and math to see if we are in 1 of three cases. 1. we are lane uniform: You are loading kv 0,1,2,3... and your expresions only depdnes on q which we know from how we added the vecotorization hook the q value will be consistent across the full span. 2. you are conitugous up to some vec width. lets say you start with positive indices and your expression is `KV%4`. Since kv starts at 0 (for this expression) we know that you can can do vector operations with vecwidth of 4 3. every lane is loading from non contiguous locations and needs to be gather. `kv*2`. That is basicallly it. I did heaps of fuzz testing I feel confident that we are conservative where we need to be. We also needed to update the codegen in flex/flash for loads to be able to render the right autovec copy. I think the main point where review is needed is in this one file since this is the bulk of new behavior. Perhaps if you are more familiar with sympy expression analysis and have better ways to structure things there that would be helpful. And maybe the codegen stuff but I feel good there Pull Request resolved: #185020 Approved by: https://github.com/eellison
reubenconducts
added a commit
to reubenconducts/flash-attention
that referenced
this pull request
Jun 2, 2026
* vectorized mask mod application for existing mask mod signatures * add vectorized mask mod examples, get vectorized evaluation and application working * guard sm80/90/120 against mask_vec_size > 2 * thread mask_vec_size thru sm80/90/120 kernel * Small tweaks coverign sm90 * Small tweaks coverign sm90 --------- Co-authored-by: drisspg <drisspguessous@gmail.com>
pytorchmergebot
pushed a commit
to pytorch/pytorch
that referenced
this pull request
Jun 2, 2026
This is alot of code huh.. welcome to 2026 ;). # Summary Why driss why. Let me tell you. Thanks to Rebuen we now have a nice vectorized maskmod support once we land: Dao-AILab/flash-attention#2261 and our codgen needs to update in order to take advantage of this! I have started landing pieces of this: #183406 for isntance which has huge wins for loading from tensors that are contiguous along the kv dim. This lets us acutally do wider vectorized loads and greatly improves perf. We want to do the same for mask mod loads. And that is 1 of the two things this PR does. From benchmarking I found that while we do see 128/256 bit width loads when expected and we can get ~ 20% perf bump we are still slower than reproted in that PR. And the main gap is that for contiguous spans `PackedMaskAnalyzer` in this pr. basically for things like causal prefixlm/ sliding window where the masking a contiguous span with lo and high bounds we can do much more optimized code using R2P and baking predication into UINT32 Non zero amount of code for q-uniform matters : For packed mask lowering, bounds may depend on batch, head, query index, or captured aux tensors, but they must be uniform across the vectorized KV lanes. I.e. They can depend on things that are the same for the whole 32-lane group: - batch index b - head index h - query index q_idx - the group start kv_idx - captured tensor loads indexed only by q-uniform values, e.g. offsets[doc_id[b, q_idx]] But they cannot depend on lane in an arbitrary way. We (mostly codex) did alot of fuzz testing to determine a sane way to filter out patterns that dont fit mostly by reusing utils form our vec load utils where feasible. This is so much code, and I hate it, but it works. I'm still trying to trim it down/make things more grokkable for me. Yes, I have read every line but still. I have not read every line of the tests, and this has mostly been FuzzTest-driven. I launched 25 subagents to fuzz test this code, and it added these unit tests. I'm going to go through and see if I can simplify them. ## Perf | shape | case | sparsity | old us | auto us | speedup | auto TFLOPs | mask vec | autovec | shift/R2P | lane pack | |------------|-------------------|------------|----------|-----------|-----------|---------------|------------|-----------|-------------|-------------| | llama8b_8k | causal | 0.484 | 460.91 | 461.6 | 0.999 | 1233 | True | 0 | True | False | | llama8b_8k | document_ids | 0.949 | 225.2 | 166.54 | 1.352 | 336.6 | True | 1 | False | True | | llama8b_8k | document_offsets | 0.895 | 317.18 | 157.22 | 2.018 | 737.1 | True | 0 | True | False | | llama8b_8k | gather_only | 0.613 | 1921.7 | 1924.16 | 0.999 | 221.8 | False | 0 | False | False | | llama8b_8k | mixed_gather_tail | 0 | 12324.1 | 3485.06 | 3.536 | 316.7 | True | 2 | False | True | | llama8b_8k | prefix_lm | 0.457 | 490.93 | 455.2 | 1.078 | 1316.6 | True | 0 | True | False | | llama8b_8k | qkv_bias | 0 | 6310.77 | 2123.01 | 2.973 | 519.9 | True | 1 | False | True | | llama8b_8k | rank1_kv | 0.484 | 1358.93 | 1360.82 | 0.999 | 418.2 | False | 0 | False | False | | llama8b_8k | sliding_window | 0.938 | 163.94 | 92.34 | 1.775 | 735.5 | True | 0 | True | False | ### Codegen updates for those curious ## qkv_bias old scalar ```python def mask_mod(_b, _h, q_idx, kv_idx): return (q_idx >= kv_idx) | (bias[q_idx, kv_idx] > 0) ``` ```python @cute.jit def mask_mod(b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): in_ptr8 = aux_tensors[0] tmp0 = q_idx tmp1 = kv_idx tmp2 = operator.ge(tmp0, tmp1) tmp3 = ssa_to_fragment(q_idx, cutlass.Int32) tmp4 = ssa_to_fragment(kv_idx, cutlass.Int32) tmp5 = cute.make_rmem_tensor(cute.size(tmp3.shape), cutlass.BFloat16) for load_idx in cutlass.range(cute.size(tmp5.shape), unroll_full=True): tmp5[load_idx] = (in_ptr8[tmp3[load_idx], tmp4[load_idx]]) tmp6 = (tmp5.load()).to(cutlass.Float32) tmp7 = operator.gt(tmp6, cute.full_like(tmp6, 0.0)) tmp8 = (tmp2 | tmp7) mask_mod_output = tmp8 return mask_mod_output mask_mod.__cute_hash__ = "cutedsl_fused_flex_attention_15852c9b_mask" ``` ## qkv_bias auto vec32 ```python @cute.jit def mask_mod(b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): in_ptr8 = aux_tensors[0] tmp0 = q_idx tmp1 = kv_idx tmp2 = operator.ge(tmp0, tmp1) tmp3 = ssa_to_fragment(q_idx, cutlass.Int32) tmp4 = ssa_to_fragment(kv_idx, cutlass.Int32) tmp5 = cute.make_rmem_tensor(cute.size(tmp3.shape), cutlass.BFloat16) tmp6 = cute.assume(tmp4[0], divby=cute.size(tmp5.shape)) tmp7 = cute.local_tile(in_ptr8, (1, cute.size(tmp5.shape)), (tmp3[0], tmp6 // cute.size(tmp5.shape))) tmp8 = cute.make_ptr(cutlass.BFloat16, tmp7.iterator.toint(), tmp7.iterator.memspace, assumed_align=min(16, cute.size(tmp5.shape) * 2)) tmp9 = cute.make_tensor(tmp8, tmp7.layout) cute.autovec_copy(tmp9[0, None], tmp5) tmp10 = (tmp5.load()).to(cutlass.Float32) tmp11 = operator.gt(tmp10, cute.full_like(tmp10, 0.0)) tmp12 = (tmp2 | tmp11) mask_mod_output = tmp12 mask_mod_packed = cute.make_rmem_tensor(1, dtype=cutlass.Uint32) mask_mod_packed[0] = cutlass.Uint32(0) for mask_lane_idx in cutlass.range_constexpr(32): mask_bit = cutlass.Uint32(1) << mask_lane_idx mask_mod_packed[0] = ( mask_mod_packed[0] | mask_bit if cutlass.Boolean(mask_mod_output[mask_lane_idx]) else mask_mod_packed[0] ) return mask_mod_packed.load() mask_mod.__cute_hash__ = "cutedsl_fused_flex_attention_78e3f466_mask" mask_mod.__mask_vec_size__ = 32 mask_mod.__vec_size__ = 32 ``` ## sliding_window old scalar ```python @cute.jit def mask_mod(b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): tmp0 = q_idx tmp1 = kv_idx tmp2 = operator.ge(tmp0, tmp1) tmp3 = (tmp0 - tmp1) tmp4 = operator.le(tmp3, cute.full_like(tmp3, 256)) tmp5 = (tmp2 & tmp4) mask_mod_output = tmp5 return mask_mod_output mask_mod.__cute_hash__ = "cutedsl_fused_flex_attention_d776a5fa_mask" ``` ## sliding_window auto packed interval ```python @cute.jit def mask_mod(b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): mask_mod_packed = cute.make_rmem_tensor(1, dtype=cutlass.Uint32) mask_mod_packed[0] = cutlass.Uint32(0) interval_lower_0 = max(cutlass.Int32(0), (cutlass.Int32(-256) + q_idx[0] + (cutlass.Int32(-1) * kv_idx[0]))) interval_upper_0 = min((cutlass.Int32(1) + q_idx[0] + (cutlass.Int32(-1) * kv_idx[0])), cutlass.Int32(32)) below_0 = utils.shr_u32( cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32( min( max( cutlass.Int32(32) - interval_upper_0, cutlass.Int32(0), ), cutlass.Int32(32), ) ), ) above_0 = utils.shl_u32( cutlass.Uint32(0xFFFFFFFF), cutlass.Uint32( min( max(interval_lower_0, cutlass.Int32(0)), cutlass.Int32(32), ) ), ) mask_mod_packed[0] = mask_mod_packed[0] | ( below_0 & above_0 ) return mask_mod_packed.load() mask_mod.__cute_hash__ = "cutedsl_fused_flex_attention_7648c49a_mask" mask_mod.__mask_vec_size__ = 32 mask_mod.__vec_size__ = 32 ``` Pull Request resolved: #184438 Approved by: https://github.com/eellison
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Follow-up to #2236. The approach to vectorizing is bipartite:
r2pThe latter is important for example in situations where
mask_moddepends onaux_tensorsthat are contiguous in the kv idx, or whenaux_tensorsdon't depend on kv index at all.mask_mods still emitTensorSSAs, but they need not be single values. These are treated as bit-packed masks.Note: this current work is Sm100 only.
See
mask_mod_definitions.pyfor many examples. Vectorization leads to a speedup in all relevant mask mods:I added tests checking bitwise equality between ordinary and vectorized paths; those and existing mask mod tests all pass (


test_mask_modandtest_mask_mod_varlen, respectively):cc @drisspg @v0i0