[Cute,Flex,Fwd] Allow vectorized score_mod definitions#2236
Conversation
| ) | ||
|
|
||
| if aux_tensors is not None: | ||
| aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors) |
There was a problem hiding this comment.
SUPER DUPER nit; aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors) if aux_tensors else None
all real estate feels quite precious in this file
| custom score_mod callables. | ||
| """ | ||
| assumed_align: int = getattr(t, "__assumed_align__", None) | ||
| leading_dim: int = getattr(t, "__leading_dim__", None) |
There was a problem hiding this comment.
I think in the future we it would be nice to find these programmatically instead of users facing (potentially)
| * cute.full_like(score, 0.125 * 0.6931471805599453 * 1.4426950408889634) | ||
| ) | ||
| diff0 = q_idx[0] - kv_idx[0] | ||
| abs_diff = cute.make_rmem_tensor(kv_idx.shape, diff0.dtype) |
There was a problem hiding this comment.
should we write a note somewhere that vec_width for fwd score-mod is always encoded in kv_idx shape?
| batch_bias = aux_tensors[0] | ||
| dtype = batch_bias.element_type | ||
| b_idx0 = b_idx[0] | ||
| bias_frag = cute.make_rmem_tensor(1, dtype) |
There was a problem hiding this comment.
and to triple check is this is actually not vectorized right?
| bias_frag = cute.make_rmem_tensor(1, dtype) | ||
| bias_frag[0] = batch_bias[b_idx0] | ||
| bias_val = (bias_frag.load()).to(cutlass.Float32) | ||
| return tSrS_ssa + bias_val |
There was a problem hiding this comment.
or maybe it is and this + is doing broadcasting? if so should we also have some doc on this pattern for aux_tensor vectorization?
| if hasattr(func, "__cute_hash__"): | ||
| return func.__cute_hash__ | ||
|
|
||
| # __vec_size__ is attr of @cute.jitted mod |
There was a problem hiding this comment.
hmm I think that since set_hash is True when we change the vecsize we are going to early return from line 40 right and not actually produce a new kernel, can you check in your tests with the for loops if this is the case? or we are producing new python funcs that dont have cute_hash set
fed436a to
04fca59
Compare
* clean up and add more vectorized tests * remove commented out change * fix typo * add aux tensor alignment to compile key * add varlen score mod vec tests * uncomment test configs * sm90 fwd * update hash callable * format hash callable * shorten vec size tests
* clean up and add more vectorized tests * remove commented out change * fix typo * add aux tensor alignment to compile key * add varlen score mod vec tests * uncomment test configs * sm90 fwd * update hash callable * format hash callable * shorten vec size tests
[Reupload of #2215 with Sm90] This PR is a
score_mod"power user" update that allows the user to specify vectorization for a givenscore_mod. It does so in two ways:score_mod.__vec_size__and have the kernel read that, instead of using the current logic (vec_size = 2 if noaux_tensorsare present, otherwise 1)buf.__assumed_align__andbuf.__leading_dim__for anyaux_tensors, allowing vectorized loads in thescore_modwhen set.These options are not exposed in the API; they must be set specific to the given
score_modandaux_tensors, and are thus a "power user" feature.For a kv bias load
score_mod, we see up to 2.9x speedup:Tests check bitwise equality with unvectorized versions of score mods.




Of course, there is added complexity in defining
score_mods to be performant, but it's strictly contained to within thescore_moddefinition (plus the 3 attributes mentioned above).Still TODO, reserved for later PRs:
mask_modapplicationcc: @drisspg @v0i0