diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 625f4b3d14c..6ce6c6d9e98 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -1676,7 +1676,6 @@ def softmax_loop( seqlen=seqlen, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, - mask_fn=partial(mask_fn, mask_seqlen=False), ) if has_work: