Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions flash_attn/cute/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,9 @@ def get_trip_start_count_via_block_info(
padded_offset_k=Int32(0),
seqlen_q=seqlen_q,
seqlen_k=seqlen_k,
m_block_offset=Int32(0),
block_idx_offset=Int32(0),
num_n_blocks=cute.ceil_div(seqlen_k, tile_shape[1]),
has_cu_seqlens_q=False,
has_cu_seqlens_k=False,
has_seqused_q=False,
Expand Down Expand Up @@ -911,6 +914,9 @@ def get_trip_mask_bounds_via_block_info(
padded_offset_k=Int32(0),
seqlen_q=seqlen_q,
seqlen_k=seqlen_k,
m_block_offset=Int32(0),
block_idx_offset=Int32(0),
num_n_blocks=cute.ceil_div(seqlen_k, tile_shape[1]),
has_cu_seqlens_q=False,
has_cu_seqlens_k=False,
has_seqused_q=False,
Expand Down
61 changes: 48 additions & 13 deletions flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,8 @@ def __call__(
cute.assume(Q.stride[1], divby=64),
Q.stride[4],
(
(Q.shape[4], Q.shape[4] * Q.shape[3]),
(
0
if varlen
else cute.assume(Q.shape[1] * Q.shape[4] * h_r * h_k, divby=64)
),
(Q.stride[3], Q.stride[2]),
0 if cumulative_s_q is not None else cute.assume(Q.stride[0], divby=64),
),
),
),
Expand All @@ -263,8 +259,8 @@ def __call__(
cute.assume(K.stride[1], divby=64),
K.stride[4],
(
(0, K.shape[4]),
(0 if varlen else cute.assume(K.shape[1] * K.shape[4] * 1 * h_k, divby=64)),
(0, K.stride[2]),
0 if cumulative_s_k is not None else cute.assume(K.stride[0], divby=64),
),
),
),
Expand All @@ -278,8 +274,8 @@ def __call__(
cute.assume(V.stride[1], divby=64),
V.stride[4],
(
(0, V.shape[4]),
(0 if varlen else cute.assume(V.shape[1] * V.shape[4] * 1 * h_k, divby=64)),
(0, V.stride[2]),
0 if cumulative_s_k is not None else cute.assume(V.stride[0], divby=64),
),
),
),
Expand All @@ -296,10 +292,49 @@ def __call__(
),
),
)
dK = cute.make_tensor(dK.iterator, K.layout)
dV = cute.make_tensor(dV.iterator, V.layout)
dK = cute.make_tensor(
dK.iterator,
cute.make_layout(
(dK.shape[1], dK.shape[4], hb),
stride=(
cute.assume(dK.stride[1], divby=64),
dK.stride[4],
(
(0, dK.stride[2]),
0 if cumulative_s_k is not None else cute.assume(dK.stride[0], divby=64),
),
),
),
)
dV = cute.make_tensor(
dV.iterator,
cute.make_layout(
(dV.shape[1], dV.shape[4], hb),
stride=(
cute.assume(dV.stride[1], divby=64),
dV.stride[4],
(
(0, dV.stride[2]),
0 if cumulative_s_k is not None else cute.assume(dV.stride[0], divby=64),
),
),
),
)
# (s, d, ((h_r, h_k), b))
dO = cute.make_tensor(dO.iterator, Q.layout)
dO = cute.make_tensor(
dO.iterator,
cute.make_layout(
(dO.shape[1], dO.shape[4], hb),
stride=(
cute.assume(dO.stride[1], divby=64),
dO.stride[4],
(
(dO.stride[3], dO.stride[2]),
0 if cumulative_s_q is not None else cute.assume(dO.stride[0], divby=64),
),
),
),
)

# (s, d, ((h_r, h_k), b)) -> (d, s, ((h_r, h_k), b))
dOT = cute.make_tensor(
Expand Down
57 changes: 48 additions & 9 deletions flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ def __call__(
s_q64 = Int64(s_q)
s_k64 = Int64(s_k)
s_lse64 = Int64(s_lse)
d64 = cute.assume(Int64(d), divby=128)
h_r64 = Int64(h_r)
h_k64 = Int64(h_k)
b64 = Int64(b)
Expand All @@ -196,39 +195,72 @@ def __call__(
# `cuseqlen_*` offsets stays within the tensor domain.
s_q_total = q_tensor.shape[1] if cum_seqlen_q is not None else s_q64
s_k_total = k_tensor.shape[1] if cum_seqlen_k is not None else s_k64
stride_b_qo = h_r64 * h_k64 * s_q64 * d64 if cum_seqlen_q is None else 0
stride_b_kv = h_k64 * s_k64 * d64 if cum_seqlen_k is None else 0
b_lse = b64 if cum_seqlen_q is None else 1
stride_b_lse = h_r64 * h_k64 * s_lse64 if cum_seqlen_q is None else 0

# (s, d, ((h_r, h_k), b))
q_layout = cute.make_layout(
(s_q_total, d, ((h_r, h_k), b)),
stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)),
stride=(
cute.assume(q_tensor.stride[1], divby=64),
q_tensor.stride[4],
(
(q_tensor.stride[3], q_tensor.stride[2]),
0 if cum_seqlen_q is not None else cute.assume(q_tensor.stride[0], divby=64),
),
),
)
q = cute.make_tensor(q_tensor.iterator, q_layout)
# (s, d, ((h_r, h_k), b))
do_layout = cute.make_layout(
(s_q_total, d, ((h_r, h_k), b)),
stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)),
stride=(
cute.assume(do_tensor.stride[1], divby=64),
do_tensor.stride[4],
(
(do_tensor.stride[3], do_tensor.stride[2]),
0 if cum_seqlen_q is not None else cute.assume(do_tensor.stride[0], divby=64),
),
),
)
do = cute.make_tensor(do_tensor.iterator, do_layout)
# (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast
k_layout = cute.make_layout(
(s_k_total, d, ((h_r, h_k), b)),
stride=(d64 * h_k64, 1, ((0, d64), stride_b_kv)),
stride=(
cute.assume(k_tensor.stride[1], divby=64),
k_tensor.stride[4],
(
(0, k_tensor.stride[2]),
0 if cum_seqlen_k is not None else cute.assume(k_tensor.stride[0], divby=64),
),
),
)
k = cute.make_tensor(k_tensor.iterator, k_layout)
# (d, s, ((h_r, h_k), b)), 0-stride for h_r to broadcast
kt_layout = cute.make_layout(
(d, s_k_total, ((h_r, h_k), b)),
stride=(1, d64 * h_k64, ((0, d64), stride_b_kv)),
stride=(
k_tensor.stride[4],
cute.assume(k_tensor.stride[1], divby=64),
(
(0, k_tensor.stride[2]),
0 if cum_seqlen_k is not None else cute.assume(k_tensor.stride[0], divby=64),
),
),
)
kt = cute.make_tensor(k_tensor.iterator, kt_layout)
# (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast
v_layout = cute.make_layout(
(s_k_total, d, ((h_r, h_k), b)),
stride=(d64 * h_k64, 1, ((0, d64), stride_b_kv)),
stride=(
cute.assume(v_tensor.stride[1], divby=64),
v_tensor.stride[4],
(
(0, v_tensor.stride[2]),
0 if cum_seqlen_k is not None else cute.assume(v_tensor.stride[0], divby=64),
),
),
)
v = cute.make_tensor(v_tensor.iterator, v_layout)
# (s, ((h_r, h_k), b))
Expand All @@ -242,7 +274,14 @@ def __call__(
# (s, d, ((h_r, h_k), b))
dq_layout = cute.make_layout(
(s_q_total, d, ((h_r, h_k), b)),
stride=(d64 * h_r64 * h_k64, 1, ((d64, d64 * h_r64), stride_b_qo)),
stride=(
cute.assume(dq_tensor.stride[1], divby=64),
dq_tensor.stride[4],
(
(dq_tensor.stride[3], dq_tensor.stride[2]),
0 if cum_seqlen_q is not None else cute.assume(dq_tensor.stride[0], divby=64),
),
),
)
dq = cute.make_tensor(dq_tensor.iterator, dq_layout)

Expand Down