Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions flash_attn/cute/flash_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def __init__(
AtomLayoutNdKV: int = 8,
AtomLayoutMdQ: int = 1,
V_in_regs: bool = False,
score_mod: cutlass.Constexpr | None = None,
score_mod_bwd: cutlass.Constexpr | None = None,
):
"""Initializes the configuration for a flash attention v2 kernel.

Expand Down Expand Up @@ -90,6 +92,8 @@ def __init__(
self.Mma_dKV_is_RS = AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_mma_warps and SdP_swapAB and not dKV_swapAB
self.V_in_regs = V_in_regs
self.share_QV_smem = V_in_regs
self.score_mod = score_mod
self.score_mod_bwd = score_mod_bwd

@staticmethod
def can_implement(
Expand Down Expand Up @@ -377,7 +381,6 @@ def __call__(
mCuSeqlensK: Optional[cute.Tensor] = None,
mSeqUsedQ: Optional[cute.Tensor] = None,
mSeqUsedK: Optional[cute.Tensor] = None,
softcap: Float32 | float | None = None,
window_size_left: Int32 | int | None = None,
window_size_right: Int32 | int | None = None,
mdQ_semaphore: Optional[cute.Tensor] = None,
Expand Down Expand Up @@ -430,7 +433,7 @@ def __call__(
tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

softmax_scale_log2 = softmax_scale * math.log2(math.e)
softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod)
self.kernel(
mQ,
mK,
Expand Down Expand Up @@ -773,6 +776,7 @@ def kernel(
smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params,
load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum,
m_block_max=m_block_max,
softmax_scale=softmax_scale,
softmax_scale_log2=softmax_scale_log2,
)

Expand Down Expand Up @@ -861,6 +865,7 @@ def compute_one_m_block(
load_Q_LSE: Callable,
load_dO_dPsum: Callable,
m_block_max: cutlass.Int32,
softmax_scale: cutlass.Float32,
softmax_scale_log2: cutlass.Float32,
mask_fn: Optional[Callable] = None,
):
Expand Down Expand Up @@ -890,13 +895,24 @@ def load_dO_next():
smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV,
swap_AB=self.SdP_swapAB,
)
acc_S_pre = cute.make_fragment_like(acc_S)
acc_S_pre.store(acc_S.load())
tLSErLSE = cute.make_fragment_like(smem_copy_params.tSsLSEMma[None, 0])
cute.autovec_copy(
smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], tLSErLSE
)
acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)
acc_S_pre_mn = layout_utils.reshape_acc_to_mn(acc_S_pre)
if cutlass.const_expr(self.score_mod is not None):
for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True):
acc_S_mn[r, None].store(
self.score_mod(
acc_S_mn[r, None].load() * softmax_scale,
0, 0, 0, 0, None, [],
Comment thread
CaesarG marked this conversation as resolved.
)
)
if cutlass.const_expr(mask_fn is not None):
mask_fn(acc_S, m_block=m_block)
acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)
bidx = 0
# if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)
# if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE)
Expand Down Expand Up @@ -926,7 +942,14 @@ def load_dO_next():
# if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)
assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum)
for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True):
acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r]))
grad_val = acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r])
if cutlass.const_expr(self.score_mod_bwd is not None):
grad_val = self.score_mod_bwd(
grad_val,
acc_S_pre_mn[r, None].load() * softmax_scale,
0, 0, 0, 0, None, [],
)
acc_dP_mn[r, None].store(grad_val)
# if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)
rP = cute.make_fragment_like(acc_S, self.dtype)
rP.store(acc_S.load().to(self.dtype))
Expand Down
1 change: 0 additions & 1 deletion flash_attn/cute/flash_bwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,6 @@ def __call__(
mCuSeqlensK: Optional[cute.Tensor] = None,
mSeqUsedQ: Optional[cute.Tensor] = None,
mSeqUsedK: Optional[cute.Tensor] = None,
softcap: Float32 | float | None = None,
window_size_left: Int32 | int | None = None,
window_size_right: Int32 | int | None = None,
mdQ_semaphore: Optional[cute.Tensor] = None,
Expand Down
1 change: 0 additions & 1 deletion flash_attn/cute/flash_bwd_sm90.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,6 @@ def __call__(
mCuSeqlensK: Optional[cute.Tensor] = None,
mSeqUsedQ: Optional[cute.Tensor] = None,
mSeqUsedK: Optional[cute.Tensor] = None,
softcap: Float32 | float | None = None,
window_size_left: Int32 | int | None = None,
window_size_right: Int32 | int | None = None,
mdQ_semaphore: Optional[cute.Tensor] = None,
Expand Down
38 changes: 36 additions & 2 deletions flash_attn/cute/flash_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import utils
from flash_attn.cute.mask import AttentionMask
from flash_attn.cute.softmax import Softmax
from flash_attn.cute.softmax import Softmax, apply_score_mod_inner
from flash_attn.cute.seqlen_info import SeqlenInfoQK
from flash_attn.cute.block_info import BlockInfo
from flash_attn.cute.pack_gqa import PackGQA
Expand Down Expand Up @@ -1145,8 +1145,8 @@ def load_V_next():
m_block,
acc_S,
n_block,
seqlen,
softmax_scale=softmax.softmax_scale,
seqlen=seqlen,
aux_tensors=aux_tensors,
fastdiv_mods=fastdiv_mods,
)
Expand Down Expand Up @@ -1185,6 +1185,40 @@ def load_K_next():
)
# if const_expr(self.num_stages > 1):
# load_K_next()
@cute.jit
def apply_score_mod(
    self,
    thr_mma_qk,
    batch_idx,
    head_idx,
    m_block,
    acc_S,
    n_block,
    softmax_scale,
    seqlen,
    aux_tensors: Optional[list] = None,
    fastdiv_mods=None,
):
    """Apply the user-provided ``score_mod`` callback to one tile of raw
    attention scores (``acc_S``), in place, before masking/softmax.

    Builds a per-thread coordinate tensor for the current
    ``(m_block, n_block)`` tile so the modifier can see each score's global
    ``(q_idx, kv_idx)`` position, then delegates to
    ``apply_score_mod_inner``.

    Args:
        thr_mma_qk: per-thread MMA view used to partition the coordinate
            tensor the same way ``acc_S`` is partitioned.
        batch_idx, head_idx: logical batch / head indices forwarded to the
            score modifier.
        m_block, n_block: tile indices along the query / key sequence axes.
        acc_S: QK accumulator fragment for this tile (modified in place).
        softmax_scale: scale the modifier may fold into the scores.
        seqlen: sequence-length info object forwarded as ``seqlen_info``.
        aux_tensors: optional extra tensors the user modifier reads.
        fastdiv_mods: optional fast-division helpers for index math.
    """
    # Identity tensor gives each element its own (row, col) coordinate;
    # domain_offset shifts it to global sequence positions for this tile.
    cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
    cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS)
    # Partition coordinates exactly like the accumulator so they line up
    # element-for-element with acc_S.
    tScS = thr_mma_qk.partition_C(cS)

    apply_score_mod_inner(
        acc_S,
        tScS,
        self.score_mod,
        batch_idx,
        head_idx,
        softmax_scale,
        self.vec_size,
        self.qk_acc_dtype,
        aux_tensors,
        fastdiv_mods,
        seqlen_info=seqlen,
        # constant_q_idx=None: q_idx varies per element in this (non-decode) path.
        constant_q_idx=None,
        # With pack_gqa, several Q heads share one KV head; the modifier needs
        # the ratio to recover the true query-head index.
        # NOTE(review): assumes self.qhead_per_kvhead is set on the class — confirm.
        qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
    )


# SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility
Expand Down
36 changes: 25 additions & 11 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,13 +543,16 @@ def _flash_attn_fwd(
and (tile_m % qhead_per_kvhead == 0 or not pack_gqa)
)

# hash score and mask mods for compile cache
score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False

if softcap is not None:
assert score_mod is None, "softcap and score_mod cannot be used together"
score_mod = utils.create_softcap_scoremod(softcap)
elif score_mod is not None:
if arch // 10 == 8:
raise NotImplementedError("Custom user-provided score_mod is not supported on SM8x architectures.")

# hash score and mask mods for compile cache
score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False

is_varlen = (
cu_seqlens_q is not None
Expand Down Expand Up @@ -1170,12 +1173,20 @@ def _flash_attn_bwd(
pack_gqa = qhead_per_kvhead > 1
# pack_gqa backward not yet supported in bwd
pack_gqa = False
if score_mod is not None:

if softcap != 0.0:
assert score_mod is None and score_mod_bwd is None, (
"softcap and score_mod/score_mod_bwd cannot be used together"
)
score_mod = utils.create_softcap_scoremod(softcap)
score_mod_bwd = utils.create_softcap_scoremod_bwd(softcap)
elif score_mod is not None:
assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided"
assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)"
assert cu_seqlens_q is None and cu_seqlens_k is None, (
"varlen + score_mod not supported in bwd yet"
)
if arch // 10 == 8:
raise NotImplementedError("Custom user-provided score_mod is not supported on SM8x architectures.")

device = q.device
out_torch_dtype = q.dtype
Expand Down Expand Up @@ -1321,7 +1332,6 @@ def _flash_attn_bwd(
causal,
window_size_left is not None,
window_size_right is not None,
softcap != 0.0,
m_block_size,
n_block_size,
num_threads,
Expand Down Expand Up @@ -1351,6 +1361,9 @@ def _flash_attn_bwd(
get_broadcast_dims(k),
get_broadcast_dims(v),
get_broadcast_dims(dout),
# Prevent TVM stride poisoning when only one block is present.
(seqlen_q_rounded // m_block_size == 1),
(seqlen_k_rounded // n_block_size == 1),
)
else:
compile_key = (
Expand All @@ -1362,7 +1375,6 @@ def _flash_attn_bwd(
causal,
window_size_left is not None,
window_size_right is not None,
softcap != 0.0,
m_block_size,
n_block_size,
num_threads,
Expand All @@ -1384,6 +1396,9 @@ def _flash_attn_bwd(
get_broadcast_dims(k),
get_broadcast_dims(v),
get_broadcast_dims(dout),
# Prevent TVM stride poisoning when only one block is present.
(seqlen_q_rounded // m_block_size == 1),
(seqlen_k_rounded // n_block_size == 1),
)
if compile_key not in _flash_attn_bwd.compile_cache:
q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [
Expand Down Expand Up @@ -1426,6 +1441,8 @@ def _flash_attn_bwd(
AtomLayoutNdKV,
AtomLayoutMdQ,
V_in_regs=V_in_regs,
score_mod=score_mod,
score_mod_bwd=score_mod_bwd,
)
elif arch // 10 == 9:
fa_bwd_obj = FlashAttentionBackwardSm90(
Expand Down Expand Up @@ -1497,7 +1514,6 @@ def _flash_attn_bwd(
cu_seqlens_k_tensor,
seqused_q_tensor,
seqused_k_tensor,
None, # softcap - not yet supported in backward
window_size_left,
window_size_right,
dQ_semaphore_tensor,
Expand All @@ -1524,7 +1540,6 @@ def _flash_attn_bwd(
cu_seqlens_k,
seqused_q,
seqused_k,
None, # softcap - not yet supported in backward
window_size_left,
window_size_right,
dQ_semaphore,
Expand Down Expand Up @@ -1723,7 +1738,6 @@ def forward(
@staticmethod
def backward(ctx, dout, dlse):
q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
assert ctx.softcap == 0.0
if not ctx.return_lse:
dlse = None
if dout is None:
Expand Down
22 changes: 17 additions & 5 deletions flash_attn/cute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,28 @@ def hash_callable(


def create_softcap_scoremod(softcap_val):
inv_softcap = 1.0 / softcap_val

@cute.jit
def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
scores = acc_S_SSA * inv_softcap
return scores * cute.math.tanh(scores, fastmath=True)
def scoremod_premask_fn(
acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
scores = acc_S_SSA / softcap_val
return softcap_val * cute.math.tanh(scores, fastmath=True)

return scoremod_premask_fn


def create_softcap_scoremod_bwd(softcap_val):
    """Build the backward counterpart of the softcap score modifier.

    The forward modifier is ``f(s) = softcap * tanh(s / softcap)``, whose
    derivative is ``f'(s) = 1 - tanh^2(s / softcap)``; the returned callback
    scales the incoming gradient by that factor.
    """

    @cute.jit
    def scoremod_bwd_fn(
        grad_out_SSA, score_SSA, batch_idx, head_idx, q_idx, kv_idx, seqlen_info, aux_tensors
    ):
        # Chain rule through softcap: dL/ds = dL/df * (1 - tanh^2(s / softcap)).
        t = cute.math.tanh(score_SSA / softcap_val, fastmath=True)
        return grad_out_SSA * (1.0 - t * t)

    return scoremod_bwd_fn


LOG2_E = math.log2(math.e)


Expand Down
33 changes: 29 additions & 4 deletions tests/cute/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import os
import random
import re
import gc
from functools import wraps

import pytest
import torch
Expand All @@ -28,8 +30,28 @@
from flash_attn.cute.interface import (
flash_attn_func,
flash_attn_varlen_func,
_flash_attn_fwd,
_flash_attn_bwd,
)

def retry_on_oom(func):
    """Decorator for GPU tests: on a CUDA out-of-memory error, flush the
    flash-attention compile caches and the CUDA allocator cache, then retry
    the wrapped test exactly once.

    Non-OOM ``torch.OutOfMemoryError`` messages (none expected in practice)
    are re-raised untouched; a second OOM on the retry propagates normally.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except torch.OutOfMemoryError as exc:
            if "out of memory" not in str(exc).lower():
                raise
            # Cached compiled kernels pin workspace memory; drop them before
            # asking the allocator to release its cached blocks.
            for entry_point in (_flash_attn_fwd, _flash_attn_bwd):
                if hasattr(entry_point, "compile_cache"):
                    entry_point.compile_cache.clear()
            gc.collect()
            torch.cuda.empty_cache()
            return func(*args, **kwargs)

    return wrapper

# torch FakeTensorMode would enable fast cutedsl kernel compilation without allocating the actual GPU memory or running the kernel
# When operating fake tensors, we cannot perform data-dependent operations (e.g., `tensor.max()`).
USE_FAKE_TENSOR = int(os.getenv("FLASH_ATTENTION_FAKE_TENSOR", 0)) == 1
Expand All @@ -50,8 +72,8 @@
@pytest.mark.parametrize("has_qv", [False])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("softcap", [0.0, 15.0])
@pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("softcap", [0.0, 15.0])
# @pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local_enum", [0, 1, 2, 3])
# @pytest.mark.parametrize("local_enum", [0])
@pytest.mark.parametrize("causal", [False, True])
Expand Down Expand Up @@ -96,6 +118,7 @@
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@retry_on_oom
@maybe_fake_tensor_mode(USE_FAKE_TENSOR)
def test_flash_attn_output(
seqlen_q,
Expand Down Expand Up @@ -388,8 +411,8 @@ def test_flash_attn_output(
@pytest.mark.parametrize("has_qv", [False])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("softcap", [0.0, 15.0])
@pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("softcap", [0.0, 15.0])
# @pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local_enum", [0, 1, 2, 3])
# @pytest.mark.parametrize("local_enum", [0])
@pytest.mark.parametrize("causal", [False, True])
Expand Down Expand Up @@ -449,6 +472,7 @@ def test_flash_attn_output(
(False, True),
],
)
@retry_on_oom
@maybe_fake_tensor_mode(USE_FAKE_TENSOR)
def test_flash_attn_varlen_output(
seqlen_q,
Expand Down Expand Up @@ -927,6 +951,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@retry_on_oom
@maybe_fake_tensor_mode(USE_FAKE_TENSOR)
def test_flash_attn_kvcache(
seqlen_q,
Expand Down
Loading
Loading