Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flash_attn/cute/flash_bwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def __call__(
if const_expr(not self.dKV_postprocess):
layout_dKV_transpose = KV_layout_transpose
else:
layout_dKV_transpose = LSE_dPsum_dQaccum_transpose
layout_dKV_transpose = [2, 1, 0] if const_expr(mCuSeqlensK is None) else [1, 0]
mdK, mdV = [utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV)]
# (s, h, n, b) --> (h, s, n, b) or (t, h, n) -> (h, t, b)
dO_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensQ is None) else [1, 0, 2]
Expand Down Expand Up @@ -2373,7 +2373,7 @@ def compute_loop(
# When total_m_block_cnt == 0 for block sparsity, no Q tiles contribute to this KV tile
if const_expr(not self.dKV_postprocess):
should_zero_dKV = False
if const_expr(self.is_local or seqlen.has_cu_seqlens_q):
if const_expr(self.is_local or self.is_varlen_q):
should_zero_dKV = m_block_min >= m_block_max
if const_expr(self.use_block_sparsity):
# For block sparsity, zero when no m_blocks contribute to this n_block
Expand Down
6 changes: 0 additions & 6 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,9 +1130,7 @@ def _flash_attn_bwd(
AtomLayoutMdQ,
dQ_swapAB,
cu_seqlens_q is None,
cu_seqlens_k is None,
seqused_q is None,
seqused_k is None,
)
if compile_key_post not in _flash_attn_bwd.compile_cache_post:
dq_accum_tensor = to_cute_tensor(dq_accum)
Expand Down Expand Up @@ -1174,9 +1172,7 @@ def _flash_attn_bwd(
num_threads,
AtomLayoutNdKV,
dKV_swapAB,
cu_seqlens_q is None,
cu_seqlens_k is None,
seqused_q is None,
seqused_k is None,
)
if compile_key_post not in _flash_attn_bwd.compile_cache_post:
Expand Down Expand Up @@ -1217,9 +1213,7 @@ def _flash_attn_bwd(
num_threads,
AtomLayoutNdKV,
dKV_swapAB,
cu_seqlens_q is None,
cu_seqlens_k is None,
seqused_q is None,
seqused_k is None,
)
if compile_key_post not in _flash_attn_bwd.compile_cache_post:
Expand Down
76 changes: 61 additions & 15 deletions tests/cute/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,15 @@ def test_flash_attn_output(
(True, True),
],
)
@pytest.mark.parametrize(
"unpad_q, unpad_kv",
[
(True, True),
(False, False),
(True, False),
(False, True),
],
)
def test_flash_attn_varlen_output(
seqlen_q,
seqlen_k,
Expand All @@ -441,6 +450,8 @@ def test_flash_attn_varlen_output(
varlen_mode,
zero_lengths_q,
zero_lengths_k,
unpad_q,
unpad_kv,
):
local = local_enum > 0
if local and causal:
Expand Down Expand Up @@ -588,8 +599,14 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
query_unused_mask=query_unused_mask,
key_unused_mask=key_unused_mask,
)
print("cu_seqlens_q = ", cu_seqlens_q)
print("cu_seqlens_k = ", cu_seqlens_k)
if unpad_q:
print("cu_seqlens_q = ", cu_seqlens_q)
else:
print("seqused_q = ", seqused_q)
if unpad_kv:
print("cu_seqlens_k = ", cu_seqlens_k)
else:
print("seqused_k = ", seqused_k)
q_unpad, k_unpad, v_unpad = [
x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
]
Expand Down Expand Up @@ -649,15 +666,15 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
if IS_SM90 and num_splits > 1:
continue
out_unpad, lse = flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
q_unpad if unpad_q else q,
k_unpad if unpad_kv else k,
v_unpad if unpad_kv else v,
cu_seqlens_q=cu_seqlens_q if unpad_q else None,
cu_seqlens_k=cu_seqlens_k if unpad_kv else None,
max_seqlen_q=seqlen_q,
max_seqlen_k=seqlen_k,
# seqused_q=seqused_q,
# seqused_k=seqused_k,
seqused_q=seqused_q if not unpad_q else None,
seqused_k=seqused_k if not unpad_kv else None,
causal=causal,
# qv=qv_unpad,
# q_descale=q_descale,
Expand All @@ -670,7 +687,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
pack_gqa=pack_gqa,
deterministic=deterministic,
)
out = output_pad_fn(out_unpad)
out = output_pad_fn(out_unpad) if unpad_q else out_unpad
if query_unused_mask is not None:
out.masked_fill_(q_zero_masking, 0.0)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
Expand Down Expand Up @@ -720,21 +737,32 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
# 0, # sm_margin
# )
dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad
out_unpad,
(
q_unpad if unpad_q else q,
k_unpad if unpad_kv else k,
v_unpad if unpad_kv else v,
),
g_unpad
)
dq = dq_pad_fn(dq_unpad)
dk = dk_pad_fn(dk_unpad)
dv = dk_pad_fn(dv_unpad)
dq = dq_pad_fn(dq_unpad) if unpad_q else dq_unpad
dk = dk_pad_fn(dk_unpad) if unpad_kv else dk_unpad
dv = dk_pad_fn(dv_unpad) if unpad_kv else dv_unpad
if key_unused_mask is not None:
k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
dk.masked_fill_(k_zero_masking, 0.0)
dv.masked_fill_(k_zero_masking, 0.0)
if query_unused_mask is not None:
dq.masked_fill_(q_zero_masking, 0.0)
if not unpad_kv:
dk.masked_fill_(rearrange(~key_padding_mask, "b s -> b s 1 1"), 0.0)
dv.masked_fill_(rearrange(~key_padding_mask, "b s -> b s 1 1"), 0.0)
if not unpad_q:
dq.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
# print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
# assert (softmax_d - do_o).abs().max().item() <= 1e-5
# assert dq_accum.abs().max().item() == 0.0
g = output_pad_fn(g_unpad)
g = output_pad_fn(g_unpad) if unpad_q else g_unpad

# qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float()
# qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
Expand Down Expand Up @@ -762,6 +790,24 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
if VERBOSE:
diff_dq = (dq - dq_ref).abs()
max_idx = diff_dq.argmax()
coords = torch.unravel_index(max_idx, diff_dq.shape)
print(f"dQ max diff: {diff_dq.max().item()}")
print(f" at coordinates {tuple(c.item() for c in coords)}: dQ={dq[coords].item()}, dQ_ref={dq_ref[coords].item()}")

diff_dk = (dk - dk_ref).abs()
max_idx = diff_dk.argmax()
coords = torch.unravel_index(max_idx, diff_dk.shape)
print(f"dK max diff: {diff_dk.max().item()}")
print(f" at coordinates {tuple(c.item() for c in coords)}: dK={dk[coords].item()}, dK_ref={dk_ref[coords].item()}")

diff_dv = (dv - dv_ref).abs()
max_idx = diff_dv.argmax()
coords = torch.unravel_index(max_idx, diff_dv.shape)
print(f"dV max diff: {diff_dv.max().item()}")
print(f" at coordinates {tuple(c.item() for c in coords)}: dV={dv[coords].item()}, dV_ref={dv_ref[coords].item()}")
# breakpoint()
dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (
0 if softcap == 0 else 3e-4
Expand Down