[Cute, Flex, Fwd, Sm100] Allow vectorized score_mod definitions#2215

Closed
reubenconducts wants to merge 7 commits into Dao-AILab:main from reubenconducts:rstern/vec-mod

@reubenconducts reubenconducts commented Jan 28, 2026

This PR is a score_mod "power user" update that allows the user to specify vectorization for a given score_mod. It does so in two ways:

  • One can set score_mod.__vec_size__ and have the kernel read that value instead of applying the current default (vec_size = 2 if no aux_tensors are present, otherwise 1).
  • One can set buf.__assumed_align__ and buf.__leading_dim__ on any of the aux_tensors, which allows vectorized loads from that tensor inside the score_mod.

These options are not exposed in the API; they must be set specifically for the given score_mod and aux_tensors, and are thus a "power user" feature.
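
A minimal sketch of how the two opt-ins might look from the user's side. It does not use the real kernel: `resolve_vec_size`, `FakeAuxTensor`, both score_mod bodies, and the simplified score_mod signature are hypothetical stand-ins, and the values given to `__assumed_align__` and `__leading_dim__` are only illustrative (their units and semantics are assumed, not taken from the PR). Only the three attribute names and the default rule come from the description above.

```python
def resolve_vec_size(score_mod, aux_tensors=None):
    """Hypothetical helper mirroring the described selection rule."""
    override = getattr(score_mod, "__vec_size__", None)
    if override is not None:
        return override
    # Default described in the PR: 2 without aux_tensors, otherwise 1.
    return 2 if not aux_tensors else 1


class FakeAuxTensor(list):
    """Stand-in for an aux tensor; a list subclass so attributes can be attached."""


def plain_mod(score, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
    # Illustrative score_mod with no opt-in attributes.
    return score


def kv_bias_mod(score, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
    # Illustrative kv-bias score_mod: add a per-key bias to the score.
    bias = aux_tensors[0]
    return score + bias[kv_idx]


# Opt in to a wider vector width for this specific score_mod.
kv_bias_mod.__vec_size__ = 8

# Annotate the aux tensor so vectorized loads could be emitted (assumed values).
bias = FakeAuxTensor([0.0] * 16)
bias.__assumed_align__ = 16
bias.__leading_dim__ = 0

print(resolve_vec_size(plain_mod))            # default, no aux_tensors
print(resolve_vec_size(plain_mod, [bias]))    # default, with aux_tensors
print(resolve_vec_size(kv_bias_mod, [bias]))  # explicit override
```

Because the attributes live on the score_mod and on the tensors themselves, nothing in the public call signature would need to change under this scheme.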

For a kv bias load score_mod, we see up to 2.9x speedup:

bias.__assumed_align__ == None
### headdim = 128, causal = True, seqlen_q = 8192, seqlen = 8192, batch_size = 1, nheads = 16, varlen = True ###
FA Python fwd with vec_size 1: 0.905ms, 303.7 TFLOPS
FA Python fwd with vec_size 2: 0.789ms, 348.3 TFLOPS
FA Python fwd with vec_size 4: 0.515ms, 533.9 TFLOPS
FA Python fwd with vec_size 8: 0.422ms, 651.0 TFLOPS
FA Python fwd with vec_size 16: 0.414ms, 663.4 TFLOPS
FA Python fwd with vec_size 32: 0.403ms, 682.3 TFLOPS
FA Python fwd with vec_size 64: 0.415ms, 661.8 TFLOPS
FA Python fwd with vec_size 128: 0.387ms, 709.6 TFLOPS

bias.__assumed_align__ == 4
### headdim = 128, causal = True, seqlen_q = 8192, seqlen = 8192, batch_size = 1, nheads = 16, varlen = True ###
FA Python fwd with vec_size 1: 0.903ms, 304.5 TFLOPS
FA Python fwd with vec_size 2: 0.752ms, 365.6 TFLOPS
FA Python fwd with vec_size 4: 0.458ms, 599.5 TFLOPS
FA Python fwd with vec_size 8: 0.366ms, 751.7 TFLOPS
FA Python fwd with vec_size 16: 0.338ms, 814.4 TFLOPS
FA Python fwd with vec_size 32: 0.330ms, 832.4 TFLOPS
FA Python fwd with vec_size 64: 0.321ms, 855.4 TFLOPS
FA Python fwd with vec_size 128: 0.336ms, 818.4 TFLOPS

bias.__assumed_align__ == 8
### headdim = 128, causal = True, seqlen_q = 8192, seqlen = 8192, batch_size = 1, nheads = 16, varlen = True ###
FA Python fwd with vec_size 1: 0.904ms, 304.1 TFLOPS
FA Python fwd with vec_size 2: 0.749ms, 366.8 TFLOPS
FA Python fwd with vec_size 4: 0.462ms, 594.7 TFLOPS
FA Python fwd with vec_size 8: 0.351ms, 783.2 TFLOPS
FA Python fwd with vec_size 16: 0.328ms, 838.3 TFLOPS
FA Python fwd with vec_size 32: 0.332ms, 827.1 TFLOPS
FA Python fwd with vec_size 64: 0.331ms, 830.1 TFLOPS
FA Python fwd with vec_size 128: 0.912ms, 301.3 TFLOPS

bias.__assumed_align__ == 16
### headdim = 128, causal = True, seqlen_q = 8192, seqlen = 8192, batch_size = 1, nheads = 16, varlen = True ###
FA Python fwd with vec_size 1: 0.904ms, 304.1 TFLOPS
FA Python fwd with vec_size 2: 0.749ms, 366.8 TFLOPS
FA Python fwd with vec_size 4: 0.462ms, 594.7 TFLOPS
FA Python fwd with vec_size 8: 0.351ms, 783.2 TFLOPS
FA Python fwd with vec_size 16: 0.328ms, 838.3 TFLOPS
FA Python fwd with vec_size 32: 0.332ms, 827.1 TFLOPS
FA Python fwd with vec_size 64: 0.331ms, 830.1 TFLOPS
FA Python fwd with vec_size 128: 0.912ms, 301.3 TFLOPS
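
For anyone who wants to compare the rows, here is a small sketch that turns the printed timings into speedups over vec_size 1 and recomputes TFLOPS. The timings are copied from the `bias.__assumed_align__ == 4` block; the FLOP model (4 * batch * nheads * seqlen_q * seqlen_k * headdim, halved for causal) is an assumption about how the benchmark script counts work, although it matches the reported TFLOPS to within the rounding of the printed times.

```python
# Timings in ms, copied from the bias.__assumed_align__ == 4 block.
timings_ms = {
    1: 0.903, 2: 0.752, 4: 0.458, 8: 0.366,
    16: 0.338, 32: 0.330, 64: 0.321, 128: 0.336,
}


def causal_fwd_tflops(ms, batch=1, nheads=16, seqlen_q=8192, seqlen_k=8192, headdim=128):
    # Assumed FLOP model: two matmuls (QK^T and PV) at 2 FLOPs per
    # multiply-add, halved because causal masking skips about half the tiles.
    flops = 4 * batch * nheads * seqlen_q * seqlen_k * headdim / 2
    return flops / (ms * 1e-3) / 1e12


baseline_ms = timings_ms[1]
for vec_size, ms in timings_ms.items():
    speedup = baseline_ms / ms
    print(f"vec_size {vec_size:>3}: {speedup:.2f}x, {causal_fwd_tflops(ms):.1f} TFLOPS")
```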

Tests pass, comparing many vectorized and unvectorized score_mods for bitwise equality:
[Screenshot, 2026-01-30 5:22:34 PM]
[Screenshot, 2026-01-30 5:29:26 PM]
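
The idea behind a bitwise comparison can be sketched in plain Python. This is not the PR's test code: `scalar_mod`, `vectorized_mod`, and `bitwise_equal` are hypothetical names, and ordinary Python floats stand in for the kernel's register fragments. The point is only that the vectorized path should produce exactly the same bits as the element-by-element path for every vec_size.

```python
import struct


def bits(x):
    # Raw IEEE-754 bytes, so -0.0 vs 0.0 and NaN payloads are compared exactly.
    return struct.pack("<d", x)


def scalar_mod(score, kv_idx, bias):
    # Reference: one score and one bias load at a time.
    return score + bias[kv_idx]


def vectorized_mod(scores, kv_start, bias):
    # Vectorized: one contiguous bias load covering len(scores) kv positions.
    chunk = bias[kv_start:kv_start + len(scores)]
    return [s + b for s, b in zip(scores, chunk)]


def bitwise_equal(scores, bias, vec_size):
    assert len(scores) % vec_size == 0, "row length must be a multiple of vec_size"
    ref = [scalar_mod(s, i, bias) for i, s in enumerate(scores)]
    out = []
    for start in range(0, len(scores), vec_size):
        out.extend(vectorized_mod(scores[start:start + vec_size], start, bias))
    return all(bits(a) == bits(b) for a, b in zip(ref, out))


scores = [0.1 * i - 0.5 for i in range(16)]
bias = [(-1.0) ** i * 0.25 * i for i in range(16)]
for vec_size in (1, 2, 4, 8, 16):
    print(vec_size, bitwise_equal(scores, bias, vec_size))
```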

Of course, defining a performant score_mod is now more complex, but that complexity is confined to the score_mod definition itself (plus the 3 attributes mentioned above).

Still TODO, reserved for later PRs:

  • Work into the backward pass
  • Vectorize mask_mod application

cc: @drisspg @v0i0

@reubenconducts changed the title from "[Cute, Flex, Fwd, Sm100] Allow vectorized score_mods" to "[Cute, Flex, Fwd, Sm100] Allow vectorized score_mod definitions" on Jan 28, 2026
Comment thread flash_attn/cute/cute_dsl_utils.py
Comment thread flash_attn/cute/interface.py
Comment thread tests/cute/score_mod_definitions.py Outdated
Comment thread tests/cute/score_mod_definitions.py

@drisspg drisspg left a comment

some comments but I like the direction


3 participants