[Cute,Flex,Fwd] Allow vectorized score_mod definitions #2236

Merged
drisspg merged 10 commits into Dao-AILab:main from reubenconducts:rstern/vec-mod on Feb 11, 2026
@reubenconducts
Contributor

[Reupload of #2215 with Sm90] This PR is a score_mod "power user" update that allows the user to specify vectorization for a given score_mod. It does so in two ways:

  • One can set score_mod.__vec_size__ and have the kernel read that, instead of using the current logic (vec_size = 2 if no aux_tensors are present, otherwise 1)
  • One can set buf.__assumed_align__ and buf.__leading_dim__ for any aux_tensors, allowing vectorized loads in the score_mod when set.

These options are not exposed in the API; they must be set specifically for the given score_mod and aux_tensors, and are thus a "power user" feature.
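
A minimal sketch of how these attributes might be attached, using plain Python stand-ins rather than the real kernel interface. The attribute names `__vec_size__`, `__assumed_align__`, and `__leading_dim__` come from the description above. Everything else is hypothetical: `resolve_vec_size`, the placeholder score_mod signature, the stand-in `bias` object, and the values 8, 16, and 1 only illustrate the fallback rule (2 without aux_tensors, otherwise 1).

```python
from types import SimpleNamespace


def resolve_vec_size(score_mod, aux_tensors=None):
    """Hypothetical helper: prefer an explicit __vec_size__, else the default rule."""
    explicit = getattr(score_mod, "__vec_size__", None)
    if explicit is not None:
        return explicit
    # Default described in the PR: 2 when no aux_tensors are present, otherwise 1.
    return 2 if not aux_tensors else 1


def kv_bias_score_mod(score, *index_args, aux_tensors=None):
    """Placeholder body and signature; a real score_mod is a CuTe DSL function."""
    return score


# Stand-in for an aux tensor; in practice this would be the actual bias buffer.
bias = SimpleNamespace()
bias.__assumed_align__ = 16  # illustrative value (its units are an assumption)
bias.__leading_dim__ = 1     # illustrative value (its meaning is an assumption)

default_vec = resolve_vec_size(kv_bias_score_mod, [bias])   # falls back to 1
kv_bias_score_mod.__vec_size__ = 8                          # opt in to wider vectors
explicit_vec = resolve_vec_size(kv_bias_score_mod, [bias])  # now reads 8
```

Because the attributes live on the callable and on the tensors, nothing in the public call signature has to change.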

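For a kv bias load score_mod, the numbers below show up to roughly a 2.8x speedup (0.905 ms at vec_size 1 versus 0.321 ms at vec_size 64 with bias.__assumed_align__ == 4):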

bias.__assumed_align__ == None
### headdim = 128, causal = True, seqlen_q = 8192, seqlen = 8192, batch_size = 1, nheads = 16, varlen = True ###
FA Python fwd with vec_size 1: 0.905ms, 303.7 TFLOPS
FA Python fwd with vec_size 2: 0.789ms, 348.3 TFLOPS
FA Python fwd with vec_size 4: 0.515ms, 533.9 TFLOPS
FA Python fwd with vec_size 8: 0.422ms, 651.0 TFLOPS
FA Python fwd with vec_size 16: 0.414ms, 663.4 TFLOPS
FA Python fwd with vec_size 32: 0.403ms, 682.3 TFLOPS
FA Python fwd with vec_size 64: 0.415ms, 661.8 TFLOPS
FA Python fwd with vec_size 128: 0.387ms, 709.6 TFLOPS

bias.__assumed_align__ == 4
### headdim = 128, causal = True, seqlen_q = 8192, seqlen = 8192, batch_size = 1, nheads = 16, varlen = True ###
FA Python fwd with vec_size 1: 0.903ms, 304.5 TFLOPS
FA Python fwd with vec_size 2: 0.752ms, 365.6 TFLOPS
FA Python fwd with vec_size 4: 0.458ms, 599.5 TFLOPS
FA Python fwd with vec_size 8: 0.366ms, 751.7 TFLOPS
FA Python fwd with vec_size 16: 0.338ms, 814.4 TFLOPS
FA Python fwd with vec_size 32: 0.330ms, 832.4 TFLOPS
FA Python fwd with vec_size 64: 0.321ms, 855.4 TFLOPS
FA Python fwd with vec_size 128: 0.336ms, 818.4 TFLOPS

bias.__assumed_align__ == 8
### headdim = 128, causal = True, seqlen_q = 8192, seqlen = 8192, batch_size = 1, nheads = 16, varlen = True ###
FA Python fwd with vec_size 1: 0.904ms, 304.1 TFLOPS
FA Python fwd with vec_size 2: 0.749ms, 366.8 TFLOPS
FA Python fwd with vec_size 4: 0.462ms, 594.7 TFLOPS
FA Python fwd with vec_size 8: 0.351ms, 783.2 TFLOPS
FA Python fwd with vec_size 16: 0.328ms, 838.3 TFLOPS
FA Python fwd with vec_size 32: 0.332ms, 827.1 TFLOPS
FA Python fwd with vec_size 64: 0.331ms, 830.1 TFLOPS
FA Python fwd with vec_size 128: 0.912ms, 301.3 TFLOPS

bias.__assumed_align__ == 16
### headdim = 128, causal = True, seqlen_q = 8192, seqlen = 8192, batch_size = 1, nheads = 16, varlen = True ###
FA Python fwd with vec_size 1: 0.904ms, 304.1 TFLOPS
FA Python fwd with vec_size 2: 0.749ms, 366.8 TFLOPS
FA Python fwd with vec_size 4: 0.462ms, 594.7 TFLOPS
FA Python fwd with vec_size 8: 0.351ms, 783.2 TFLOPS
FA Python fwd with vec_size 16: 0.328ms, 838.3 TFLOPS
FA Python fwd with vec_size 32: 0.332ms, 827.1 TFLOPS
FA Python fwd with vec_size 64: 0.331ms, 830.1 TFLOPS
FA Python fwd with vec_size 128: 0.912ms, 301.3 TFLOPS
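
As a sanity check on the logs above, the TFLOPS and speedup figures can be recomputed from the reported times. This sketch assumes the common forward-attention FLOP count of 4 * batch * nheads * seqlen_q * seqlen_k * headdim, halved for causal masking; the benchmark script itself is not shown in this PR, so that formula is an assumption, although it does reproduce the reported 303.7 TFLOPS baseline. The function names are hypothetical.

```python
def attention_fwd_flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal):
    """Forward FLOPs under the 4*S_q*S_k*D convention (assumed, not from the PR)."""
    flops = 4 * batch * nheads * seqlen_q * seqlen_k * headdim
    return flops // 2 if causal else flops


def tflops(flops, time_ms):
    """Convert a FLOP count and a time in milliseconds to TFLOPS."""
    return flops / (time_ms * 1e-3) / 1e12


flops = attention_fwd_flops(1, 16, 8192, 8192, 128, causal=True)
baseline_ms = 0.905  # vec_size 1, bias.__assumed_align__ == None
best_ms = 0.321      # vec_size 64, bias.__assumed_align__ == 4

baseline_tflops = tflops(flops, baseline_ms)
speedup = baseline_ms / best_ms
print(f"{baseline_tflops:.1f} TFLOPS baseline, {speedup:.2f}x speedup")
# prints "303.7 TFLOPS baseline, 2.82x speedup"
```

The logged times are rounded to three decimals, so recomputed TFLOPS for the faster runs can differ from the logged values in the last digit or two.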

Tests check bitwise equality with unvectorized versions of score mods.

Of course, defining score_mods to be performant adds complexity, but that complexity is strictly contained within the score_mod definition (plus the three attributes mentioned above).

Still TODO, reserved for later PRs:

  • Work into the backward pass
  • Vectorize mask_mod application

cc: @drisspg @v0i0

@drisspg self-requested a review on February 5, 2026, 17:01
)

if aux_tensors is not None:
    aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors)
Collaborator

SUPER DUPER nit; aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors) if aux_tensors else None

all real estate feels quite precious in this file
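
For reference, the two spellings in this thread can be compared side by side. This is a sketch with a stub `get_aux_tensor_metadata` (the real helper's return value is not shown here), and it assumes `aux_tensor_metadata` defaults to `None` in the block form. The one subtle difference is that `if aux_tensors` also treats an empty list as absent, whereas `is not None` does not.

```python
def get_aux_tensor_metadata(aux_tensors):
    """Stub standing in for the real helper; returns one entry per tensor."""
    return tuple(type(t).__name__ for t in aux_tensors)


def metadata_block_form(aux_tensors):
    # Block form from the diff, with an assumed default of None.
    aux_tensor_metadata = None
    if aux_tensors is not None:
        aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors)
    return aux_tensor_metadata


def metadata_one_liner(aux_tensors):
    # Suggested one-liner: truthiness check instead of an identity check.
    return get_aux_tensor_metadata(aux_tensors) if aux_tensors else None


same_for_none = metadata_block_form(None) is metadata_one_liner(None)
same_for_values = metadata_block_form([1.0]) == metadata_one_liner([1.0])
empty_block, empty_one_liner = metadata_block_form([]), metadata_one_liner([])
```

With an empty list, the block form yields an empty tuple and the one-liner yields `None`; whether that matters depends on how callers treat the two.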

custom score_mod callables.
"""
assumed_align: int = getattr(t, "__assumed_align__", None)
leading_dim: int = getattr(t, "__leading_dim__", None)
Collaborator

I think in the future it would be nice to find these programmatically instead of having them be user-facing (potentially)
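
A rough sketch of the lookup quoted above, to show what "user-facing" means in practice: the hints are read with `getattr` and default to `None` when the user has not set them. Only the two `getattr` calls and the name `get_aux_tensor_metadata` appear in the diff; the return shape (a tuple of pairs, hashable so it could feed a compile key) and the stand-in tensor objects are assumptions.

```python
from types import SimpleNamespace


def get_aux_tensor_metadata(aux_tensors):
    """Hypothetical reconstruction: one (assumed_align, leading_dim) pair per tensor."""
    return tuple(
        (
            getattr(t, "__assumed_align__", None),
            getattr(t, "__leading_dim__", None),
        )
        for t in aux_tensors
    )


plain = SimpleNamespace()      # no hints set: both fields fall back to None
hinted = SimpleNamespace()
hinted.__assumed_align__ = 16  # illustrative value
hinted.__leading_dim__ = 0     # illustrative value

metadata = get_aux_tensor_metadata([plain, hinted])
# metadata == ((None, None), (16, 0))
```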

* cute.full_like(score, 0.125 * 0.6931471805599453 * 1.4426950408889634)
)
diff0 = q_idx[0] - kv_idx[0]
abs_diff = cute.make_rmem_tensor(kv_idx.shape, diff0.dtype)
Collaborator

should we write a note somewhere that vec_width for fwd score-mod is always encoded in kv_idx shape?
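
To illustrate the convention being asked about, here is a pure-Python analogue in which the vector width is never passed explicitly and is instead read from the length of `kv_idx`. Lists stand in for register fragments. The ALiBi-style formula, the sign, and the default slope of 0.125 are assumptions made for illustration, since the actual score_mod body is only partially visible in the diff.

```python
def alibi_like_score_mod(scores, q_idx, kv_idx, slope=0.125):
    """Pure-Python stand-in: one query row, len(kv_idx) key columns per call."""
    vec_width = len(kv_idx)  # the width is implied by kv_idx, not a parameter
    if len(scores) != vec_width:
        raise ValueError("scores and kv_idx must have the same length")
    q0 = q_idx[0]  # a single query index, mirroring `q_idx[0]` in the diff
    return [s - slope * abs(q0 - kv) for s, kv in zip(scores, kv_idx)]


out = alibi_like_score_mod([1.0, 1.0], q_idx=[3], kv_idx=[1, 2])
# out == [0.75, 0.875]
```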

batch_bias = aux_tensors[0]
dtype = batch_bias.element_type
b_idx0 = b_idx[0]
bias_frag = cute.make_rmem_tensor(1, dtype)
Collaborator

and to triple check: this is actually not vectorized, right?

bias_frag = cute.make_rmem_tensor(1, dtype)
bias_frag[0] = batch_bias[b_idx0]
bias_val = (bias_frag.load()).to(cutlass.Float32)
return tSrS_ssa + bias_val
Collaborator

or maybe it is and this + is doing broadcasting? if so should we also have some doc on this pattern for aux_tensor vectorization?
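
If the broadcasting reading is right, the pattern would look roughly like the following pure-Python analogue: a single scalar load indexed by batch, then an elementwise add across the whole score fragment. Lists stand in for register fragments and `add_batch_bias` is a hypothetical name; this shows the shape of the pattern and makes no claim about how the CuTe DSL lowers the `+`.

```python
def add_batch_bias(scores, batch_bias, b_idx):
    """One scalar load per call, broadcast over every element of the fragment."""
    bias_val = float(batch_bias[b_idx[0]])  # length-1 load, as in bias_frag above
    return [s + bias_val for s in scores]   # the same scalar applied to each score


biased = add_batch_bias([1.0, 2.0, 3.0], batch_bias=[0.5, -1.0], b_idx=[1])
# biased == [0.0, 1.0, 2.0]
```

In this form the load itself is not vectorized; only the add spans the vector, which is why a per-batch bias does not need alignment hints.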

Comment thread flash_attn/cute/utils.py Outdated
if hasattr(func, "__cute_hash__"):
    return func.__cute_hash__

# __vec_size__ is attr of @cute.jitted mod
Collaborator

hmm, I think that since set_hash is True, when we change the vec size we are going to early return from line 40 and not actually produce a new kernel, right? Can you check in your tests with the for loops whether this is the case? Or are we producing new Python funcs that don't have cute_hash set?
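
The concern can be reproduced in isolation. The sketch below is a toy model of the caching behaviour, not the code in `flash_attn/cute/utils.py`: `hash_with_early_return` mimics the early return on `__cute_hash__`, and `hash_with_vec_size` is one hypothetical way to fold `__vec_size__` into the key so that changing the vector width yields a different cache entry.

```python
import hashlib


def hash_with_early_return(func):
    """Pattern under discussion: a cached hash short-circuits everything else."""
    if hasattr(func, "__cute_hash__"):
        return func.__cute_hash__
    return hashlib.sha256(func.__code__.co_code).hexdigest()


def hash_with_vec_size(func):
    """Hypothetical variant that includes __vec_size__ in the key."""
    base = hash_with_early_return(func)
    vec_size = getattr(func, "__vec_size__", None)
    return f"{base}:vec={vec_size}"


def score_mod(score):
    return score


score_mod.__cute_hash__ = "abc123"  # pretend the hash was cached on first use

score_mod.__vec_size__ = 2
early_2, keyed_2 = hash_with_early_return(score_mod), hash_with_vec_size(score_mod)
score_mod.__vec_size__ = 4
early_4, keyed_4 = hash_with_early_return(score_mod), hash_with_vec_size(score_mod)
```

Here `early_2 == early_4`, so a cache keyed on the early-return hash would reuse the kernel compiled for the old width, while `keyed_2 != keyed_4`.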

Comment thread tests/cute/test_score_mod.py Outdated
@drisspg merged commit c4d8b06 into Dao-AILab:main on Feb 11, 2026
5t4r1i9ht pushed a commit to 5t4r1i9ht/flash-attention that referenced this pull request Mar 15, 2026
* clean up and add more vectorized tests

* remove commented out change

* fix typo

* add aux tensor alignment to compile key

* add varlen score mod vec tests

* uncomment test configs

* sm90 fwd

* update hash callable

* format hash callable

* shorten vec size tests
ussoewwin pushed a commit to ussoewwin/flash-attention that referenced this pull request May 13, 2026
2 participants