-
Notifications
You must be signed in to change notification settings - Fork 2.8k
[Cute,Flex,Fwd] Allow vectorized score_mod definitions #2236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fce9870
cf14ef1
85566ec
6d9ef84
01d228a
782f9bd
11447a4
ab02dd3
04fca59
eaa4a49
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,7 +41,7 @@ | |
|
|
||
|
|
||
| from flash_attn.cute import utils | ||
| from flash_attn.cute.cute_dsl_utils import to_cute_tensor | ||
| from flash_attn.cute.cute_dsl_utils import to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata | ||
| from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90 | ||
| from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 | ||
| from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess | ||
|
|
@@ -368,7 +368,11 @@ def _flash_attn_fwd( | |
| block_size=(m_block_size, n_block_size), | ||
| q_stage=q_stage, | ||
| ) | ||
|
|
||
| if aux_tensors is not None: | ||
| aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SUPER DUPER nit; aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors) if aux_tensors else None all real estate feels quite precious in this file |
||
| else: | ||
| aux_tensor_metadata = None | ||
|
|
||
| compile_key = ( | ||
| dtype, | ||
| head_dim, | ||
|
|
@@ -379,7 +383,7 @@ def _flash_attn_fwd( | |
| mask_mod_hash, | ||
| use_block_sparsity, | ||
| block_sparse_broadcast_pattern, | ||
| len(aux_tensors) if aux_tensors is not None else 0, | ||
| aux_tensor_metadata, | ||
| lse is None, | ||
| cu_seqlens_q is None, | ||
| cu_seqlens_k is None, | ||
|
|
@@ -432,8 +436,9 @@ def _flash_attn_fwd( | |
| sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors) | ||
|
|
||
| cute_aux_tensors = None | ||
| aux_tensor_metadata = None | ||
| if aux_tensors is not None: | ||
| cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors] | ||
| cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] | ||
|
|
||
| if compute_capability == 9: | ||
| assert page_table is None, "paged KV not supported on SM 9.0" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,19 +15,47 @@ def score_mod_identity(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_t | |
| return tSrS_ssa | ||
|
|
||
|
|
||
| @cute.jit | ||
| def score_mod_identity_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): | ||
| return tSrS_ssa | ||
|
|
||
|
|
||
| @cute.jit | ||
| def score_mod_causal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): | ||
| mask = operator.ge(q_idx, kv_idx) | ||
| return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) | ||
|
|
||
|
|
||
| @cute.jit | ||
| def score_mod_causal_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): | ||
| mask = cute.make_rmem_tensor(kv_idx.shape, dtype=cutlass.Boolean) | ||
| kv_idx0 = kv_idx[0] | ||
| q_idx0 = q_idx[0] | ||
| for i in cutlass.range_constexpr(cute.size(mask.shape)): | ||
| mask[i] = q_idx0 >= kv_idx0 + i | ||
| mask_ssa = mask.load() | ||
| return cute.where(mask_ssa, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) | ||
|
|
||
|
|
||
| @cute.jit | ||
| def score_mod_rel_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): | ||
| diff = q_idx - kv_idx | ||
| abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype) | ||
| return tSrS_ssa + abs_diff.to(cutlass.Float32) | ||
|
|
||
|
|
||
| @cute.jit | ||
| def score_mod_rel_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): | ||
| q_idx0 = q_idx[0] | ||
| kv_idx0 = kv_idx[0] | ||
| diff0 = q_idx0 - kv_idx0 | ||
| abs_diff = cute.make_rmem_tensor(kv_idx.shape, dtype=diff0.dtype) | ||
| for i in cutlass.range_constexpr(cute.size(kv_idx.shape)): | ||
| diffi = diff0 - i | ||
| abs_diff[i] = mlir_math.absi(diffi) | ||
| return tSrS_ssa + abs_diff.load().to(cutlass.Float32) | ||
|
|
||
|
|
||
| @cute.jit | ||
| def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): | ||
| diff = q_idx - kv_idx | ||
|
|
@@ -36,10 +64,25 @@ def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, au | |
| return tSrS_ssa + scaled.to(cutlass.Float32) | ||
|
|
||
|
|
||
| @cute.jit | ||
| def score_mod_rel_bias_x2_vectorized( | ||
| tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors | ||
| ): | ||
| q_idx0 = q_idx[0] | ||
| kv_idx0 = kv_idx[0] | ||
| diff0 = q_idx0 - kv_idx0 | ||
| abs_diff_x2 = cute.make_rmem_tensor(kv_idx.shape, dtype=diff0.dtype) | ||
| for i in cutlass.range_constexpr(cute.size(kv_idx.shape)): | ||
| diffi = diff0 - i | ||
| abs_diff_x2[i] = mlir_math.absi(diffi) * 2 | ||
| return tSrS_ssa + abs_diff_x2.load().to(cutlass.Float32) | ||
|
|
||
|
|
||
| @cute.jit | ||
| def score_mod_times_two(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): | ||
| return tSrS_ssa * cute.full_like(tSrS_ssa, 2) | ||
|
|
||
| score_mod_times_two_vectorized = score_mod_times_two | ||
|
|
||
| @cute.jit | ||
| def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): | ||
|
|
@@ -53,6 +96,21 @@ def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tens | |
| abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype).to(cutlass.Float32) | ||
| return score - slope * abs_diff | ||
|
|
||
| @cute.jit | ||
| def score_mod_alibi_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): | ||
| score = tSrS_ssa.to(cutlass.Float32) | ||
| slope_exp = (h_idx + cute.full_like(h_idx, 1)) * cute.full_like(h_idx, -8) | ||
| slope = cute.math.exp2( | ||
| slope_exp.to(cutlass.Float32) | ||
| * cute.full_like(score, 0.125 * 0.6931471805599453 * 1.4426950408889634) | ||
| ) | ||
| diff0 = q_idx[0] - kv_idx[0] | ||
| abs_diff = cute.make_rmem_tensor(kv_idx.shape, diff0.dtype) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we write a note somewhere that vec_width for fwd score-mod is always encoded in kv_idx shape? |
||
| for i in cutlass.range_constexpr(cute.size(abs_diff.shape)): | ||
| diffi = diff0 - i | ||
| abs_diff[i] = mlir_math.absi(diffi) | ||
| return score - slope * abs_diff.load().to(cutlass.Float32) | ||
|
|
||
|
|
||
| @cute.jit | ||
| def score_mod_sliding_window(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): | ||
|
|
@@ -88,6 +146,16 @@ def score_mod_batch_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux | |
| bias_val = (bias_frag.load()).to(cutlass.Float32) | ||
| return tSrS_ssa + bias_val | ||
|
|
||
| @cute.jit | ||
| def score_mod_batch_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): | ||
| batch_bias = aux_tensors[0] | ||
| dtype = batch_bias.element_type | ||
| b_idx0 = b_idx[0] | ||
| bias_frag = cute.make_rmem_tensor(1, dtype) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and to triple check is this is actually not vectorized right? |
||
| bias_frag[0] = batch_bias[b_idx0] | ||
| bias_val = (bias_frag.load()).to(cutlass.Float32) | ||
| return tSrS_ssa + bias_val | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. or maybe it is and this + is doing broadcasting? if so should we also have some doc on this pattern for aux_tensor vectorization? |
||
|
|
||
|
|
||
| @cute.jit | ||
| def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): | ||
|
|
@@ -109,6 +177,22 @@ def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, au | |
|
|
||
| return tSrS_ssa + head_val + pos_val | ||
|
|
||
| @cute.jit | ||
| def score_mod_dual_buffer_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): | ||
| head_bias = aux_tensors[0] | ||
| pos_bias = aux_tensors[1] | ||
| dtype = head_bias.element_type | ||
|
|
||
| head_val_frag = cute.make_fragment(1, dtype) | ||
| head_val_frag[0] = head_bias[h_idx[0]] | ||
| head_val = (head_val_frag.load()).to(cutlass.Float32) | ||
|
|
||
| pos_val_frag = cute.make_fragment(1, dtype) | ||
| pos_val_frag[0] = pos_bias[q_idx[0]] | ||
| pos_val = (pos_val_frag.load()).to(cutlass.Float32) | ||
|
|
||
| return tSrS_ssa + head_val + pos_val | ||
|
|
||
|
|
||
| # ============================================================================= | ||
| # Score_mod functions that use global indices | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think in the future we it would be nice to find these programmatically instead of users facing (potentially)