1 change: 1 addition & 0 deletions .gitignore
@@ -50,3 +50,4 @@ training/data
# ck modules
csrc/composable_kernel
csrc/cutlass
.analysis
234 changes: 149 additions & 85 deletions flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py
@@ -99,18 +99,19 @@ def get_autotune_configs():
# Here is the I/O shape:
# Out: (batch, nheads_q, max_seqlens_q, headDim)
# DO: (batch, nheads_q, max_seqlens_q, headDim)
# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at
# fwd_prefill.py line 607
# Delta: (batch, nheads_q, max_seqlens_q)
@triton.autotune(
configs=preprocess_autotune_configs,
key=preprocess_autotune_keys,
use_cuda_graph=True,
)
@triton.jit
def _bwd_preprocess(
O, DO, # noqa: E741
O,
DO, # noqa: E741
Delta,
stride_ob, stride_oh, stride_om, stride_od,
stride_dob, stride_doh, stride_dom, stride_dod,
stride_delta_b, stride_delta_h, stride_delta_m,
stride_descale_do_z,
cu_seqlens_q, max_seqlen_q,
@@ -125,8 +126,6 @@ def _bwd_preprocess(
bid = tl.program_id(1)
hid = tl.program_id(2)
# Handle varlen
q_start = 0
seqlen_q = max_seqlen_q
if IS_VARLEN:
q_start = tl.load(cu_seqlens_q + bid)
q_end = tl.load(cu_seqlens_q + bid + 1)
@@ -138,32 +137,41 @@
# Compute offsets
offs_m = pid_m * PRE_BLOCK + tl.arange(0, PRE_BLOCK)
offs_d = tl.arange(0, HEAD_DIM)
# Offset O/DO by batch, head and q_start
O += bid * stride_ob + hid * stride_oh + q_start * stride_om # noqa: E741
DO += bid * stride_ob + hid * stride_oh + q_start * stride_om
# pointer offsets for O & DO
off_o = (bid * stride_ob
+ hid * stride_oh
+ q_start * stride_om
+ offs_m[:, None] * stride_om
+ offs_d[None, :] * stride_od) # noqa: E741
off_do = (bid * stride_dob
+ hid * stride_doh
+ q_start * stride_dom
+ offs_m[:, None] * stride_dom
+ offs_d[None, :] * stride_dod)

# create masks
mask_m = offs_m < seqlen_q
mask_md = mask_m[:, None]
PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM)
if PADDED_HEAD:
mask_md &= offs_d[None, :] < ACTUAL_HEAD_DIM
# compute pointers
offs_do = offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
out_ptrs = O + offs_do
do_ptrs = DO + offs_do
# load
o = tl.load(out_ptrs, mask=mask_md, other=0.0)
do = tl.load(do_ptrs, mask=mask_md, other=0.0)
o = tl.load(O + off_o, mask=mask_md, other=0.0)
do = tl.load(DO + off_do, mask=mask_md, other=0.0)
# compute and write-back to delta
if IS_FP8:
descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hid)
off_descale_do = bid * stride_descale_do_z + hid
descale_do = tl.load(Descale_do + off_descale_do)

# NOTE: do is in the fp8 range and o is not in fp8
delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1)
else:
delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1)
delta_offset = Delta + bid * stride_delta_b + hid * stride_delta_h + q_start * stride_delta_m
tl.store(delta_offset + offs_m * stride_delta_m, delta, mask=mask_m)
off_delta = (bid * stride_delta_b
+ hid * stride_delta_h
+ q_start * stride_delta_m
+ offs_m * stride_delta_m)
tl.store(Delta + off_delta, delta, mask=mask_m)

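# [Editor's note] Illustrative reference, not part of this diff: a plain-PyTorch
# equivalent of what _bwd_preprocess produces for a bshd layout. The helper name and
# signature are hypothetical and it relies on this module's existing torch import; the
# Triton kernel above computes the same row-wise sums blockwise, with the fp8 descale
# applied to do.
def _bwd_preprocess_reference(o, do, descale_do=None):
    # o, do: (batch, seqlen_q, nheads_q, head_dim); descale_do: (batch, nheads_q) or None
    o_f32 = o.to(torch.float32)
    do_f32 = do.to(torch.float32)
    if descale_do is not None:
        # fp8 gradients are stored scaled; multiplying by descale_do recovers real values
        do_f32 = do_f32 * descale_do[:, None, :, None]
    # Delta[b, h, m] = sum_d O[b, m, h, d] * dO[b, m, h, d], same shape as softmax_lse
    return (o_f32 * do_f32).sum(dim=-1).transpose(1, 2)  # (batch, nheads_q, seqlen_q)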

# The main inner-loop logic for computing dK and dV.
@@ -1063,8 +1071,9 @@ def is_contiguous(x, name):
print(f"{name} is not contiguous")
return x.contiguous()


OLD_LSE = os.environ.get('OLD_LSE', '0').lower() in ('1', 'true', 'yes')
OLD_LSE: bool = False
DEBUG_TRITON: bool = False
DEBUG_TRITON_DETAIL: bool = False

def attention_prefill_backward_triton_split_fused_no_atomics_impl(
do: torch.Tensor,
@@ -1098,25 +1107,118 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
descale_dk: Optional[torch.Tensor],
descale_dv: Optional[torch.Tensor],
):
# debug
DEBUG_TRITON: bool = False
DEBUG_TRITON_DETAIL: bool = False

# do = is_contiguous(do, "do")
# q = is_contiguous(q, "q")
# k = is_contiguous(k, "k")
# v = is_contiguous(v, "v")
# o = is_contiguous(o, "o")
# softmax_lse = is_contiguous(softmax_lse, "softmax_lse")
# dq = is_contiguous(dq, "dq")
# dk = is_contiguous(dk, "dk")
# dv = is_contiguous(dv, "dv")
# get params, strides and shape
IS_VARLEN = layout == "thd"
use_dropout = (dropout_p > 0.0)

# common assertions
assert 0.0 <= dropout_p <= 1.0, f"dropout_p must be between 0 and 1, got {dropout_p}"
assert q.device == k.device == v.device == o.device == do.device == softmax_lse.device, \
f"All tensors must be on the same device. Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}, do={do.device}, softmax_lse={softmax_lse.device}"
assert q.dtype == k.dtype == v.dtype == do.dtype, "q, k, v, do must have the same dtype"
current_device = torch.cuda.current_device()
assert q.is_cuda and q.device.index == current_device, f"Device mismatch: Kernel will launch on cuda:{current_device}, but tensors are on {q.device}"

# get shapes and strides
if IS_VARLEN:
# shape
total_seqlen_q, nheads_q, head_size_q = q.shape
total_seqlen_k, nheads_k, head_size_k = k.shape
total_seqlen_v, nheads_v, head_size_v = v.shape
nheads_lse, total_seqlen_lse = softmax_lse.shape

# assert shapes
assert total_seqlen_lse == total_seqlen_q, f"softmax_lse seqlen {total_seqlen_lse} != q seqlen {total_seqlen_q}"
assert cu_seqlens_q is not None, "cu_seqlens_q must be provided for varlen layout"
assert cu_seqlens_k is not None, "cu_seqlens_k must be provided for varlen layout"
assert max_seqlen_q is not None, "max_seqlen_q must be provided for varlen layout"
assert max_seqlen_k is not None, "max_seqlen_k must be provided for varlen layout"

# assert head dimensions
assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}"
assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}"
assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA"
assert nheads_lse == nheads_q, f"softmax_lse heads {nheads_lse} != q heads {nheads_q}"

# assert output shapes
assert o.shape == (total_seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_q)}"
assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}"
assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}"
assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}"
assert dv.shape == v.shape, f"dv shape {dv.shape} != v shape {v.shape}"

# assert cu_seqlens
assert cu_seqlens_q.dtype == torch.int32, f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}"
assert cu_seqlens_k.dtype == torch.int32, f"cu_seqlens_k must be int32, got {cu_seqlens_k.dtype}"
assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0"
assert cu_seqlens_k[0] == 0, "cu_seqlens_k must start with 0"
assert cu_seqlens_q[-1] == total_seqlen_q, f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} != total_seqlen_q {total_seqlen_q}"
assert cu_seqlens_k[-1] == total_seqlen_k, f"cu_seqlens_k[-1] {cu_seqlens_k[-1]} != total_seqlen_k {total_seqlen_k}"

# set vars
batch = len(cu_seqlens_q) - 1
head_size = head_size_q

# strides
stride_qb, stride_qm, stride_qh, stride_qd = 0, q.stride(0), q.stride(1), q.stride(2)
stride_kb, stride_kn, stride_kh, stride_kd = 0, k.stride(0), k.stride(1), k.stride(2)
stride_vb, stride_vn, stride_vh, stride_vd = 0, v.stride(0), v.stride(1), v.stride(2)
stride_ob, stride_om, stride_oh, stride_od = 0, o.stride(0), o.stride(1), o.stride(2)
stride_dqb, stride_dqm, stride_dqh, stride_dqd = 0, dq.stride(0), dq.stride(1), dq.stride(2)
stride_dkb, stride_dkn, stride_dkh, stride_dkd = 0, dk.stride(0), dk.stride(1), dk.stride(2)
stride_dvb, stride_dvn, stride_dvh, stride_dvd = 0, dv.stride(0), dv.stride(1), dv.stride(2)
stride_dob, stride_dom, stride_doh, stride_dod = 0, do.stride(0), do.stride(1), do.stride(2)
stride_lse_b, stride_lse_h, stride_lse_m = (0, softmax_lse.stride(0), softmax_lse.stride(1))
else:
# shapes
batch_q, seqlen_q, nheads_q, head_size_q = q.shape
batch_k, seqlen_k, nheads_k, head_size_k = k.shape
batch_v, seqlen_v, nheads_v, head_size_v = v.shape
batch_lse, nheads_lse, seqlen_lse = softmax_lse.shape

# assert batch dimensions
assert batch_q == batch_k == batch_v, f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}"

# assert head dimensions
assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}"
assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}"
assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA"

# assert sequence lengths
assert seqlen_k == seqlen_v, f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}"

# assert output shapes
assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected"
assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}"
assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}"
assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}"
assert dv.shape == v.shape, f"dv shape {dv.shape} != v shape {v.shape}"

# assert softmax_lse shape
assert softmax_lse.shape == (batch_q, nheads_q, seqlen_q), f"softmax_lse shape {softmax_lse.shape} != expected"

# set vars
batch = batch_q
head_size = head_size_q
max_seqlen_q = seqlen_q
max_seqlen_k = seqlen_k

# strides
stride_qb, stride_qm, stride_qh, stride_qd = q.stride()
stride_kb, stride_kn, stride_kh, stride_kd = k.stride()
stride_vb, stride_vn, stride_vh, stride_vd = v.stride()
stride_ob, stride_om, stride_oh, stride_od = o.stride()
stride_dqb, stride_dqm, stride_dqh, stride_dqd = dq.stride()
stride_dkb, stride_dkn, stride_dkh, stride_dkd = dk.stride()
stride_dvb, stride_dvn, stride_dvh, stride_dvd = dv.stride()
stride_dob, stride_dom, stride_doh, stride_dod = do.stride()
stride_lse_b, stride_lse_h, stride_lse_m = softmax_lse.stride()
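# [Editor's note] Added comment, not part of this diff: in the varlen ("thd") branch above,
# q/k/v/o and their gradients are packed as (total_tokens, nheads, head_dim), so the batch
# stride is set to 0 and each sequence is located through cu_seqlens_q/cu_seqlens_k offsets
# inside the kernels (see q_start in _bwd_preprocess); in the "bshd" branch the four strides
# come directly from the (batch, seqlen, nheads, head_dim) tensors.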

# fp8 setup - moved after all assertions
IS_FP8 = is_fp8(q)
if IS_FP8:
FP8_MAX = torch.finfo(q.dtype).max
# assert that the main inputs are fp8
assert is_fp8(do) and is_fp8(q) and is_fp8(k) and is_fp8(v), f"Non fp8 type found: do.dtype={do.dtype}, q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}. All tensors must be fp8."
# we already asserted that do, q, k, v all have the same dtype, so no need to check each one
if is_fp8(o):
FP8_OUTPUT = True
assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o."
@@ -1136,45 +1238,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
FP8_OUTPUT = False
stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = stride_descale_do_z = None
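# [Editor's note] Added comment, not part of this diff: each descale_* tensor is assumed to
# hold one scale factor per (batch, head), indexed as bid * stride_descale_*_z + hid in the
# kernels; fp8 operands are mapped back to real values as float32(x) * descale_x before any
# accumulation, as _bwd_preprocess does with descale_do above.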


# get params, strides and shape
IS_VARLEN = layout == "thd"
use_dropout = (dropout_p > 0.0)

# get shapes and strides
if IS_VARLEN:
# shape
_, nheads_q, head_size = q.shape
_, nheads_k, _ = k.shape
batch = len(cu_seqlens_q) - 1
max_seqlen_q_final = max_seqlen_q
max_seqlen_k_final = max_seqlen_k

# strides
stride_qb, stride_qh, stride_qm, stride_qd = 0, q.stride(1), q.stride(0), q.stride(2)
stride_kb, stride_kh, stride_kn, stride_kd = 0, k.stride(1), k.stride(0), k.stride(2)
stride_vb, stride_vh, stride_vn, stride_vd = 0, v.stride(1), v.stride(0), v.stride(2)
stride_ob, stride_oh, stride_om, stride_od = 0, o.stride(1), o.stride(0), o.stride(2)
stride_dqb, stride_dqh, stride_dqm, stride_dqd = 0, dq.stride(1), dq.stride(0), dq.stride(2)
stride_dkb, stride_dkh, stride_dkn, stride_dkd = 0, dk.stride(1), dk.stride(0), dk.stride(2)
stride_dvb, stride_dvh, stride_dvn, stride_dvd = 0, dv.stride(1), dv.stride(0), dv.stride(2)
stride_dob, stride_doh, stride_dom, stride_dod = 0, do.stride(1), do.stride(0), do.stride(2)
stride_lse_b, stride_lse_h, stride_lse_m = (0, softmax_lse.stride(0), softmax_lse.stride(1))
else:
# shapes
batch, max_seqlen_q_final, nheads_q, head_size = q.shape
_, max_seqlen_k_final, nheads_k, _ = k.shape

# strides
stride_qb, stride_qh, stride_qm, stride_qd = q.stride(0), q.stride(2), q.stride(1), q.stride(3)
stride_kb, stride_kh, stride_kn, stride_kd = k.stride(0), k.stride(2), k.stride(1), k.stride(3)
stride_vb, stride_vh, stride_vn, stride_vd = v.stride(0), v.stride(2), v.stride(1), v.stride(3)
stride_ob, stride_oh, stride_om, stride_od = o.stride(0), o.stride(2), o.stride(1), o.stride(3)
stride_dqb, stride_dqh, stride_dqm, stride_dqd = dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3)
stride_dkb, stride_dkh, stride_dkn, stride_dkd = dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3)
stride_dvb, stride_dvh, stride_dvn, stride_dvd = dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3)
stride_dob, stride_doh, stride_dom, stride_dod = do.stride(0), do.stride(2), do.stride(1), do.stride(3)
stride_lse_b, stride_lse_h, stride_lse_m = softmax_lse.stride()
# alibi setup
use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0))

# get closest power of 2 over or equal to 32.
@@ -1193,28 +1257,28 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
else:
if IS_VARLEN:
# interface expects the varlen sequence dims to be rounded like this. Not sure why.
batch_size = cu_seqlens_q.numel() - 1
total_q, num_heads, _ = q.shape
total_q_rounded = total_q + 128 * batch_size
total_q_rounded = total_q + 128 * batch
delta_padded = torch.zeros((nheads_q, total_q_rounded), device=q.device, dtype=torch.float32)
delta = delta_padded[:, :total_q]
stride_delta_b, stride_delta_h, stride_delta_m = 0, delta.stride(0), delta.stride(1)
else:
# the interface expects the sequence dimension to be rounded up to a multiple of 128
max_seqlen_q_rounded = round_multiple(max_seqlen_q_final, 128)
max_seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
delta_padded = torch.zeros((batch, nheads_q, max_seqlen_q_rounded),
device=softmax_lse.device, dtype=torch.float32)
delta = delta_padded[:, :, :max_seqlen_q_final]
device=q.device, dtype=torch.float32)
delta = delta_padded[:, :, :max_seqlen_q]
stride_delta_b, stride_delta_h, stride_delta_m = delta.stride()
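# [Editor's note] Added comment, not part of this diff: round_multiple is assumed to round up
# to the nearest multiple, i.e. round_multiple(x, m) == ((x + m - 1) // m) * m, so e.g. a
# max_seqlen_q of 300 pads to 384 here; delta is then a view of delta_padded restricted to
# the real sequence length.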

pre_grid = lambda META: (triton.cdiv(max_seqlen_q_final, META['PRE_BLOCK']), batch, nheads_q)
pre_grid = lambda META: (triton.cdiv(max_seqlen_q, META['PRE_BLOCK']), batch, nheads_q)
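# [Editor's note] Added comment, not part of this diff: the preprocess launch is a 3D grid of
# (ceil(max_seqlen_q / PRE_BLOCK), batch, nheads_q) programs, matching the bid and hid read
# from tl.program_id(1) and tl.program_id(2) inside _bwd_preprocess.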
_bwd_preprocess[pre_grid](
o, do,
delta,
stride_ob, stride_oh, stride_om, stride_od,
stride_dob, stride_doh, stride_dom, stride_dod,
stride_delta_b, stride_delta_h, stride_delta_m,
stride_descale_do_z,
cu_seqlens_q, max_seqlen_q_final,
cu_seqlens_q, max_seqlen_q,
descale_do,
HEAD_DIM=HEAD_DIM,
ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM,
@@ -1232,7 +1296,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
(0, 0, 0, 0)
if use_dropout:
dropout_mask = torch.zeros(
(batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final),
(batch, nheads_q, max_seqlen_q, max_seqlen_k),
device=q.device,
dtype=torch.float32
)
@@ -1241,7 +1305,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
if not IS_VARLEN:
dropout_mask = create_dropout_mask(
dropout_p,
(batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final),
(batch, nheads_q, max_seqlen_q, max_seqlen_k),
seed = philox_seed
)
else:
@@ -1252,7 +1316,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \
dropout_mask.stride()

seqlen = max(max_seqlen_q_final, max_seqlen_k_final)
seqlen = max(max_seqlen_q, max_seqlen_k)
grid = lambda META: (nheads_k, (seqlen + META['BLOCK_N1'] - 1) // META['BLOCK_N1'], batch, )
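# [Editor's note] Added comment, not part of this diff: the fused backward launches one program
# per (kv head, sequence block, batch element); axis 1 is sized by max(max_seqlen_q,
# max_seqlen_k), presumably so a single grid covers both the key-block (dk/dv) and
# query-block (dq) portions of the fused kernel.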
if causal:
if DEBUG_TRITON: print(f"bwd_kernel: grid = {grid}") # noqa: E701
@@ -1273,7 +1337,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
stride_az, stride_ah,
nheads_q, nheads_k,
cu_seqlens_q, cu_seqlens_k,
max_seqlen_q_final, max_seqlen_k_final,
max_seqlen_q, max_seqlen_k,
dropout_mask, dropout_p, philox_seed, philox_offset,
alibi_slopes,
descale_q, descale_k, descale_v, descale_do,
@@ -1307,7 +1371,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
stride_az, stride_ah,
nheads_q, nheads_k,
cu_seqlens_q, cu_seqlens_k,
max_seqlen_q_final, max_seqlen_k_final,
max_seqlen_q, max_seqlen_k,
dropout_mask, dropout_p, philox_seed, philox_offset,
alibi_slopes,
descale_q, descale_k, descale_v, descale_do,