diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index 31186618569..5d34a35790b 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -129,7 +129,7 @@ def create_softcap_scoremod(softcap_val): inv_softcap = 1.0 / softcap_val @cute.jit - def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors): + def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, seqlen_info, aux_tensors): scores = acc_S_SSA * inv_softcap return scores * cute.math.tanh(scores, fastmath=True)