diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py
index 824abdda139..eeb7615b1d3 100644
--- a/flash_attn/cute/flash_bwd.py
+++ b/flash_attn/cute/flash_bwd.py
@@ -46,6 +46,8 @@ def __init__(
         AtomLayoutNdKV: int = 8,
         AtomLayoutMdQ: int = 1,
         V_in_regs: bool = False,
+        score_mod: cutlass.Constexpr | None = None,
+        score_mod_bwd: cutlass.Constexpr | None = None,
     ):
         """Initializes the configuration for a flash attention v2 kernel.
@@ -90,6 +92,8 @@ def __init__(
         self.Mma_dKV_is_RS = AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_mma_warps and SdP_swapAB and not dKV_swapAB
         self.V_in_regs = V_in_regs
         self.share_QV_smem = V_in_regs
+        self.score_mod = score_mod
+        self.score_mod_bwd = score_mod_bwd

     @staticmethod
     def can_implement(
@@ -377,7 +381,6 @@ def __call__(
         mCuSeqlensK: Optional[cute.Tensor] = None,
         mSeqUsedQ: Optional[cute.Tensor] = None,
         mSeqUsedK: Optional[cute.Tensor] = None,
-        softcap: Float32 | float | None = None,
         window_size_left: Int32 | int | None = None,
         window_size_right: Int32 | int | None = None,
         mdQ_semaphore: Optional[cute.Tensor] = None,
@@ -430,7 +433,7 @@ def __call__(
         tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
         grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

-        softmax_scale_log2 = softmax_scale * math.log2(math.e)
+        softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod)
         self.kernel(
             mQ,
             mK,
@@ -773,6 +776,7 @@ def kernel(
             smem_copy_params=smem_copy_params,
             gmem_copy_params=gmem_copy_params,
             load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum, m_block_max=m_block_max,
+            softmax_scale=softmax_scale,
             softmax_scale_log2=softmax_scale_log2,
         )
@@ -861,6 +865,7 @@ def compute_one_m_block(
         load_Q_LSE: Callable,
         load_dO_dPsum: Callable,
         m_block_max: cutlass.Int32,
+        softmax_scale: cutlass.Float32,
         softmax_scale_log2: cutlass.Float32,
         mask_fn: Optional[Callable] = None,
     ):
@@ -890,13 +895,24 @@ def load_dO_next():
             smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV,
             swap_AB=self.SdP_swapAB,
         )
+        acc_S_pre = cute.make_fragment_like(acc_S)
+        acc_S_pre.store(acc_S.load())
         tLSErLSE = cute.make_fragment_like(smem_copy_params.tSsLSEMma[None, 0])
         cute.autovec_copy(
             smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], tLSErLSE
         )
+        acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)
+        acc_S_pre_mn = layout_utils.reshape_acc_to_mn(acc_S_pre)
+        if cutlass.const_expr(self.score_mod is not None):
+            for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True):
+                acc_S_mn[r, None].store(
+                    self.score_mod(
+                        acc_S_mn[r, None].load() * softmax_scale,
+                        0, 0, 0, 0, None, [],
+                    )
+                )
         if cutlass.const_expr(mask_fn is not None):
             mask_fn(acc_S, m_block=m_block)
-        acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)
         bidx = 0
         # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)
         # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE)
@@ -926,7 +942,14 @@ def load_dO_next():
         # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)
         assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum)
         for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True):
-            acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r]))
+            grad_val = acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r])
+            if cutlass.const_expr(self.score_mod_bwd is not None):
+                grad_val = self.score_mod_bwd(
+                    grad_val,
+                    acc_S_pre_mn[r, None].load() * softmax_scale,
+                    0, 0, 0, 0, None, [],
+                )
+            acc_dP_mn[r, None].store(grad_val)
         # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)
         rP = cute.make_fragment_like(acc_S, self.dtype)
         rP.store(acc_S.load().to(self.dtype))
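# NOTE (editorial, not part of the patch): a minimal eager-PyTorch sketch of the chain
# rule that the new score_mod_bwd hook in compute_one_m_block above mirrors. Names here
# (ref_dS, dP_acc, dPsum) are illustrative only, not the kernel's API; acc_S_pre in the
# kernel corresponds to S_pre below, i.e. the scaled scores saved before score_mod.
import torch

def ref_dS(S_pre, P, dP_acc, dPsum, score_mod_bwd):
    # Softmax backward: gradient w.r.t. the post-score_mod logits,
    # dS_post = P * (dO @ V^T - rowsum(dO * O)), with dPsum broadcast over keys.
    dS_post = P * (dP_acc - dPsum.unsqueeze(-1))
    # score_mod_bwd maps that gradient back through the user score transform,
    # evaluated at the pre-mod (softmax_scale-multiplied) scores.
    return score_mod_bwd(dS_post, S_pre)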
diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py
index e06cd811fc6..4b4083eda9e 100644
--- a/flash_attn/cute/flash_bwd_sm100.py
+++ b/flash_attn/cute/flash_bwd_sm100.py
@@ -456,7 +456,6 @@ def __call__(
         mCuSeqlensK: Optional[cute.Tensor] = None,
         mSeqUsedQ: Optional[cute.Tensor] = None,
         mSeqUsedK: Optional[cute.Tensor] = None,
-        softcap: Float32 | float | None = None,
         window_size_left: Int32 | int | None = None,
         window_size_right: Int32 | int | None = None,
         mdQ_semaphore: Optional[cute.Tensor] = None,
diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py
index f724b5a11e3..c9a690d1e90 100644
--- a/flash_attn/cute/flash_bwd_sm90.py
+++ b/flash_attn/cute/flash_bwd_sm90.py
@@ -350,7 +350,6 @@ def __call__(
         mCuSeqlensK: Optional[cute.Tensor] = None,
         mSeqUsedQ: Optional[cute.Tensor] = None,
         mSeqUsedK: Optional[cute.Tensor] = None,
-        softcap: Float32 | float | None = None,
         window_size_left: Int32 | int | None = None,
         window_size_right: Int32 | int | None = None,
         mdQ_semaphore: Optional[cute.Tensor] = None,
diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py
index 4d47fab109f..d1a43cfd247 100644
--- a/flash_attn/cute/flash_fwd.py
+++ b/flash_attn/cute/flash_fwd.py
@@ -27,7 +27,7 @@
 from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
 from flash_attn.cute import utils
 from flash_attn.cute.mask import AttentionMask
-from flash_attn.cute.softmax import Softmax
+from flash_attn.cute.softmax import Softmax, apply_score_mod_inner
 from flash_attn.cute.seqlen_info import SeqlenInfoQK
 from flash_attn.cute.block_info import BlockInfo
 from flash_attn.cute.pack_gqa import PackGQA
@@ -1145,8 +1145,8 @@ def load_V_next():
                 m_block,
                 acc_S,
                 n_block,
-                seqlen,
                 softmax_scale=softmax.softmax_scale,
+                seqlen=seqlen,
                 aux_tensors=aux_tensors,
                 fastdiv_mods=fastdiv_mods,
             )
@@ -1185,6 +1185,40 @@ def load_K_next():
             )
             # if const_expr(self.num_stages > 1):
             #     load_K_next()
+    @cute.jit
+    def apply_score_mod(
+        self,
+        thr_mma_qk,
+        batch_idx,
+        head_idx,
+        m_block,
+        acc_S,
+        n_block,
+        softmax_scale,
+        seqlen,
+        aux_tensors: Optional[list] = None,
+        fastdiv_mods=None,
+    ):
+        # Prepare index tensor
+        cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
+        cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS)
+        tScS = thr_mma_qk.partition_C(cS)
+
+        apply_score_mod_inner(
+            acc_S,
+            tScS,
+            self.score_mod,
+            batch_idx,
+            head_idx,
+            softmax_scale,
+            self.vec_size,
+            self.qk_acc_dtype,
+            aux_tensors,
+            fastdiv_mods,
+            seqlen_info=seqlen,
+            constant_q_idx=None,
+            qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
+        )


 # SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility
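# NOTE (editorial, not part of the patch): eager reference of where score_mod sits in the
# forward pass. The kernel applies it tile-by-tile via apply_score_mod_inner to the
# already softmax_scale-multiplied scores, before masking and softmax; this sketch does
# the same on the full score matrix. Function and variable names are illustrative only.
import torch

def ref_forward(q, k, v, softmax_scale, score_mod):
    scores = torch.einsum("qd,kd->qk", q, k) * softmax_scale
    scores = score_mod(scores)              # user hook on scaled, pre-mask scores
    probs = torch.softmax(scores, dim=-1)
    return probs @ v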
diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py
index ef624677f01..d23536a90db 100644
--- a/flash_attn/cute/interface.py
+++ b/flash_attn/cute/interface.py
@@ -543,13 +543,16 @@ def _flash_attn_fwd(
         and (tile_m % qhead_per_kvhead == 0 or not pack_gqa)
     )

-    # hash score and mask mods for compile cache
-    score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
-    mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False
-
     if softcap is not None:
         assert score_mod is None, "softcap and score_mod cannot be used together"
         score_mod = utils.create_softcap_scoremod(softcap)
+    elif score_mod is not None:
+        if arch // 10 == 8:
+            raise NotImplementedError("Custom user-provided score_mod is not supported on SM8x architectures.")
+
+    # hash score and mask mods for compile cache
+    score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
+    mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False

     is_varlen = (
         cu_seqlens_q is not None
@@ -1170,12 +1173,20 @@ def _flash_attn_bwd(
     pack_gqa = qhead_per_kvhead > 1
     # pack_gqa backward not yet supported in bwd
     pack_gqa = False
-    if score_mod is not None:
+
+    if softcap != 0.0:
+        assert score_mod is None and score_mod_bwd is None, (
+            "softcap and score_mod/score_mod_bwd cannot be used together"
+        )
+        score_mod = utils.create_softcap_scoremod(softcap)
+        score_mod_bwd = utils.create_softcap_scoremod_bwd(softcap)
+    elif score_mod is not None:
         assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided"
-        assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)"
         assert cu_seqlens_q is None and cu_seqlens_k is None, (
             "varlen + score_mod not supported in bwd yet"
         )
+        if arch // 10 == 8:
+            raise NotImplementedError("Custom user-provided score_mod is not supported on SM8x architectures.")

     device = q.device
     out_torch_dtype = q.dtype
@@ -1321,7 +1332,6 @@ def _flash_attn_bwd(
             causal,
             window_size_left is not None,
             window_size_right is not None,
-            softcap != 0.0,
             m_block_size,
             n_block_size,
             num_threads,
@@ -1351,6 +1361,9 @@ def _flash_attn_bwd(
             get_broadcast_dims(k),
             get_broadcast_dims(v),
             get_broadcast_dims(dout),
+            # Prevent TVM stride poisoning when only one block is present.
+            (seqlen_q_rounded // m_block_size == 1),
+            (seqlen_k_rounded // n_block_size == 1),
         )
     else:
         compile_key = (
@@ -1362,7 +1375,6 @@ def _flash_attn_bwd(
             causal,
             window_size_left is not None,
             window_size_right is not None,
-            softcap != 0.0,
             m_block_size,
             n_block_size,
             num_threads,
@@ -1384,6 +1396,9 @@ def _flash_attn_bwd(
             get_broadcast_dims(k),
             get_broadcast_dims(v),
             get_broadcast_dims(dout),
+            # Prevent TVM stride poisoning when only one block is present.
+            (seqlen_q_rounded // m_block_size == 1),
+            (seqlen_k_rounded // n_block_size == 1),
         )
     if compile_key not in _flash_attn_bwd.compile_cache:
         q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [
@@ -1426,6 +1441,8 @@
                 AtomLayoutNdKV,
                 AtomLayoutMdQ,
                 V_in_regs=V_in_regs,
+                score_mod=score_mod,
+                score_mod_bwd=score_mod_bwd,
             )
         elif arch // 10 == 9:
             fa_bwd_obj = FlashAttentionBackwardSm90(
@@ -1497,7 +1514,6 @@ def _flash_attn_bwd(
             cu_seqlens_k_tensor,
             seqused_q_tensor,
             seqused_k_tensor,
-            None, # softcap - not yet supported in backward
             window_size_left,
             window_size_right,
             dQ_semaphore_tensor,
@@ -1524,7 +1540,6 @@ def _flash_attn_bwd(
         cu_seqlens_k,
         seqused_q,
         seqused_k,
-        None, # softcap - not yet supported in backward
        window_size_left,
        window_size_right,
        dQ_semaphore,
@@ -1723,7 +1738,6 @@ def forward(
     @staticmethod
     def backward(ctx, dout, dlse):
         q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
-        assert ctx.softcap == 0.0
         if not ctx.return_lse:
             dlse = None
         if dout is None:
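# NOTE (editorial, not part of the patch): _flash_attn_bwd now requires a matching
# score_mod_bwd whenever a custom score_mod is supplied. The contract, modeled on the
# softcap pair in utils.py below, is that score_mod_bwd(g, s, ...) returns
# g * d(score_mod)/ds evaluated at the pre-mod scores s. A hypothetical pair
# (create_scaled_scoremod and its alpha parameter are made up for illustration):
import cutlass.cute as cute

def create_scaled_scoremod(alpha):
    @cute.jit
    def scoremod_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
        return acc_S_SSA * alpha   # forward: f(s) = alpha * s

    @cute.jit
    def scoremod_bwd_fn(grad_out_SSA, score_SSA, batch_idx, head_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
        return grad_out_SSA * alpha   # backward: g * f'(s) = g * alpha

    return scoremod_fn, scoremod_bwd_fn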
diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py
index 31186618569..76579c81cc7 100644
--- a/flash_attn/cute/utils.py
+++ b/flash_attn/cute/utils.py
@@ -126,16 +126,28 @@ def hash_callable(


 def create_softcap_scoremod(softcap_val):
-    inv_softcap = 1.0 / softcap_val
-
     @cute.jit
-    def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
-        scores = acc_S_SSA * inv_softcap
-        return scores * cute.math.tanh(scores, fastmath=True)
+    def scoremod_premask_fn(
+        acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, seqlen_info, aux_tensors
+    ):
+        scores = acc_S_SSA / softcap_val
+        return softcap_val * cute.math.tanh(scores, fastmath=True)

     return scoremod_premask_fn


+def create_softcap_scoremod_bwd(softcap_val):
+    @cute.jit
+    def scoremod_bwd_fn(
+        grad_out_SSA, score_SSA, batch_idx, head_idx, q_idx, kv_idx, seqlen_info, aux_tensors
+    ):
+        scores = score_SSA / softcap_val
+        tanh_scores = cute.math.tanh(scores, fastmath=True)
+        return grad_out_SSA * (1.0 - tanh_scores * tanh_scores)
+
+    return scoremod_bwd_fn
+
+
 LOG2_E = math.log2(math.e)
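# NOTE (editorial, not part of the patch): quick plain-PyTorch sanity check that the
# softcap backward above is the autograd derivative of f(s) = softcap * tanh(s / softcap),
# i.e. g * (1 - tanh(s / softcap)^2). Illustrative only; not part of the test suite.
import torch

softcap = 15.0
s = torch.randn(64, dtype=torch.float64, requires_grad=True)
g = torch.randn(64, dtype=torch.float64)
f = softcap * torch.tanh(s / softcap)
f.backward(g)                                            # autograd: s.grad = g * f'(s)
manual = g * (1.0 - torch.tanh(s.detach() / softcap) ** 2)   # matches scoremod_bwd_fn above
assert torch.allclose(s.grad, manual)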
diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py
index 69e6308fb60..5f2c7732956 100644
--- a/tests/cute/test_flash_attn.py
+++ b/tests/cute/test_flash_attn.py
@@ -5,6 +5,8 @@
 import os
 import random
 import re
+import gc
+from functools import wraps

 import pytest
 import torch
@@ -28,8 +30,28 @@
 from flash_attn.cute.interface import (
     flash_attn_func,
     flash_attn_varlen_func,
+    _flash_attn_fwd,
+    _flash_attn_bwd,
 )

+def retry_on_oom(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except torch.OutOfMemoryError as e:
+            if "out of memory" in str(e).lower():
+                if hasattr(_flash_attn_fwd, "compile_cache"):
+                    _flash_attn_fwd.compile_cache.clear()
+                if hasattr(_flash_attn_bwd, "compile_cache"):
+                    _flash_attn_bwd.compile_cache.clear()
+                gc.collect()
+                torch.cuda.empty_cache()
+                return func(*args, **kwargs)
+            else:
+                raise
+    return wrapper
+
 # torch FakeTensorMode would enable fast cutedsl kernel compilation without allocating the actual GPU memory or running the kernel
 # When operating fake tensors, we cannot perform data-dependent operations (e.g., `tensor.max()`).
 USE_FAKE_TENSOR = int(os.getenv("FLASH_ATTENTION_FAKE_TENSOR", 0)) == 1
@@ -50,8 +72,8 @@
 @pytest.mark.parametrize("has_qv", [False])
 @pytest.mark.parametrize("deterministic", [False, True])
 # @pytest.mark.parametrize("deterministic", [False])
-# @pytest.mark.parametrize("softcap", [0.0, 15.0])
-@pytest.mark.parametrize("softcap", [0.0])
+@pytest.mark.parametrize("softcap", [0.0, 15.0])
+# @pytest.mark.parametrize("softcap", [0.0])
 @pytest.mark.parametrize("local_enum", [0, 1, 2, 3])
 # @pytest.mark.parametrize("local_enum", [0])
 @pytest.mark.parametrize("causal", [False, True])
@@ -96,6 +118,7 @@
     ],
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
+@retry_on_oom
 @maybe_fake_tensor_mode(USE_FAKE_TENSOR)
 def test_flash_attn_output(
     seqlen_q,
@@ -388,8 +411,8 @@ def test_flash_attn_output(
 @pytest.mark.parametrize("has_qv", [False])
 @pytest.mark.parametrize("deterministic", [False, True])
 # @pytest.mark.parametrize("deterministic", [False])
-# @pytest.mark.parametrize("softcap", [0.0, 15.0])
-@pytest.mark.parametrize("softcap", [0.0])
+@pytest.mark.parametrize("softcap", [0.0, 15.0])
+# @pytest.mark.parametrize("softcap", [0.0])
 @pytest.mark.parametrize("local_enum", [0, 1, 2, 3])
 # @pytest.mark.parametrize("local_enum", [0])
 @pytest.mark.parametrize("causal", [False, True])
@@ -449,6 +472,7 @@ def test_flash_attn_output(
         (False, True),
     ],
 )
+@retry_on_oom
 @maybe_fake_tensor_mode(USE_FAKE_TENSOR)
 def test_flash_attn_varlen_output(
     seqlen_q,
@@ -927,6 +951,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
     ],
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
+@retry_on_oom
 @maybe_fake_tensor_mode(USE_FAKE_TENSOR)
 def test_flash_attn_kvcache(
     seqlen_q,
diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py
index a9b8799f4c1..18295e01843 100644
--- a/tests/cute/test_flash_attn_race_condition.py
+++ b/tests/cute/test_flash_attn_race_condition.py
@@ -43,8 +43,8 @@
 @pytest.mark.parametrize("has_qv", [False])
 # @pytest.mark.parametrize("deterministic", [False, True])
 @pytest.mark.parametrize("deterministic", [True])
-# @pytest.mark.parametrize("softcap", [0.0, 15.0])
-@pytest.mark.parametrize("softcap", [0.0])
+@pytest.mark.parametrize("softcap", [0.0, 15.0])
+# @pytest.mark.parametrize("softcap", [0.0])
 # @pytest.mark.parametrize("local_enum", [0, 1, 2, 3])
 @pytest.mark.parametrize("local_enum", [0, 1])
 @pytest.mark.parametrize("causal", [False, True])
@@ -356,8 +356,8 @@ def test_flash_attn_output(
 @pytest.mark.parametrize("has_qv", [False])
 # @pytest.mark.parametrize("deterministic", [False, True])
 @pytest.mark.parametrize("deterministic", [True])
-# @pytest.mark.parametrize("softcap", [0.0, 15.0])
-@pytest.mark.parametrize("softcap", [0.0])
+@pytest.mark.parametrize("softcap", [0.0, 15.0])
+# @pytest.mark.parametrize("softcap", [0.0])
 # @pytest.mark.parametrize("local_enum", [0, 1, 2, 3])
 @pytest.mark.parametrize("local_enum", [0, 1])
 @pytest.mark.parametrize("causal", [False, True])
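# NOTE (editorial, not part of the patch): minimal end-to-end illustration of the softcap
# backward path that the re-enabled parametrizations above now exercise. Assumes
# flash_attn_func accepts the softcap keyword as used by these tests; the isinstance
# guard hedges on whether it returns just `out` or an (out, lse) tuple.
import torch
from flash_attn.cute.interface import flash_attn_func

q, k, v = [
    torch.randn(2, 256, 4, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
]
ret = flash_attn_func(q, k, v, causal=True, softcap=15.0)
out = ret[0] if isinstance(ret, tuple) else ret
out.sum().backward()    # runs the backward kernel with the softcap score_mod/score_mod_bwd pair
print(q.grad.shape, k.grad.shape, v.grad.shape)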