diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py
index ed4154edbf3..0b0488963ba 100644
--- a/flash_attn/cute/flash_bwd_sm100.py
+++ b/flash_attn/cute/flash_bwd_sm100.py
@@ -447,7 +447,7 @@ def __call__(
         if const_expr(not self.dKV_postprocess):
             layout_dKV_transpose = KV_layout_transpose
         else:
-            layout_dKV_transpose = LSE_dPsum_dQaccum_transpose
+            layout_dKV_transpose = [2, 1, 0] if const_expr(mCuSeqlensK is None) else [1, 0]
         mdK, mdV = [utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV)]
         # (s, h, n, b) --> (h, s, n, b) or (t, h, n) -> (h, t, b)
         dO_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensQ is None) else [1, 0, 2]
@@ -2373,7 +2373,7 @@ def compute_loop(
         # When total_m_block_cnt == 0 for block sparsity, no Q tiles contribute to this KV tile
         if const_expr(not self.dKV_postprocess):
             should_zero_dKV = False
-            if const_expr(self.is_local or seqlen.has_cu_seqlens_q):
+            if const_expr(self.is_local or self.is_varlen_q):
                 should_zero_dKV = m_block_min >= m_block_max
             if const_expr(self.use_block_sparsity):
                 # For block sparsity, zero when no m_blocks contribute to this n_block
diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py
index c902a17bb6e..9d5b25b25e0 100644
--- a/flash_attn/cute/interface.py
+++ b/flash_attn/cute/interface.py
@@ -1130,9 +1130,7 @@ def _flash_attn_bwd(
             AtomLayoutMdQ,
             dQ_swapAB,
             cu_seqlens_q is None,
-            cu_seqlens_k is None,
             seqused_q is None,
-            seqused_k is None,
         )
         if compile_key_post not in _flash_attn_bwd.compile_cache_post:
             dq_accum_tensor = to_cute_tensor(dq_accum)
@@ -1174,9 +1172,7 @@ def _flash_attn_bwd(
             num_threads,
             AtomLayoutNdKV,
             dKV_swapAB,
-            cu_seqlens_q is None,
             cu_seqlens_k is None,
-            seqused_q is None,
             seqused_k is None,
         )
         if compile_key_post not in _flash_attn_bwd.compile_cache_post:
@@ -1217,9 +1213,7 @@ def _flash_attn_bwd(
             num_threads,
             AtomLayoutNdKV,
             dKV_swapAB,
-            cu_seqlens_q is None,
             cu_seqlens_k is None,
-            seqused_q is None,
             seqused_k is None,
         )
         if compile_key_post not in _flash_attn_bwd.compile_cache_post:
diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py
index 471cd35711e..1c2088dd28a 100644
--- a/tests/cute/test_flash_attn.py
+++ b/tests/cute/test_flash_attn.py
@@ -425,6 +425,15 @@ def test_flash_attn_output(
         (True, True),
     ],
 )
+@pytest.mark.parametrize(
+    "unpad_q, unpad_kv",
+    [
+        (True, True),
+        (False, False),
+        (True, False),
+        (False, True),
+    ],
+)
 def test_flash_attn_varlen_output(
     seqlen_q,
     seqlen_k,
@@ -441,6 +450,8 @@ def test_flash_attn_varlen_output(
     varlen_mode,
     zero_lengths_q,
     zero_lengths_k,
+    unpad_q,
+    unpad_kv,
 ):
     local = local_enum > 0
     if local and causal:
@@ -588,8 +599,14 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
         query_unused_mask=query_unused_mask,
         key_unused_mask=key_unused_mask,
     )
-    print("cu_seqlens_q = ", cu_seqlens_q)
-    print("cu_seqlens_k = ", cu_seqlens_k)
+    if unpad_q:
+        print("cu_seqlens_q = ", cu_seqlens_q)
+    else:
+        print("seqused_q = ", seqused_q)
+    if unpad_kv:
+        print("cu_seqlens_k = ", cu_seqlens_k)
+    else:
+        print("seqused_k = ", seqused_k)
     q_unpad, k_unpad, v_unpad = [
         x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
     ]
@@ -649,15 +666,15 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
         if IS_SM90 and num_splits > 1:
             continue
         out_unpad, lse = flash_attn_varlen_func(
-            q_unpad,
-            k_unpad,
-            v_unpad,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_k=cu_seqlens_k,
+            q_unpad if unpad_q else q,
+            k_unpad if unpad_kv else k,
+            v_unpad if unpad_kv else v,
+            cu_seqlens_q=cu_seqlens_q if unpad_q else None,
+            cu_seqlens_k=cu_seqlens_k if unpad_kv else None,
             max_seqlen_q=seqlen_q,
             max_seqlen_k=seqlen_k,
-            # seqused_q=seqused_q,
-            # seqused_k=seqused_k,
+            seqused_q=seqused_q if not unpad_q else None,
+            seqused_k=seqused_k if not unpad_kv else None,
             causal=causal,
             # qv=qv_unpad,
             # q_descale=q_descale,
@@ -670,7 +687,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
             pack_gqa=pack_gqa,
             deterministic=deterministic,
         )
-        out = output_pad_fn(out_unpad)
+        out = output_pad_fn(out_unpad) if unpad_q else out_unpad
         if query_unused_mask is not None:
             out.masked_fill_(q_zero_masking, 0.0)
         print(f"Output max diff: {(out - out_ref).abs().max().item()}")
@@ -720,21 +737,32 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
         #     0,  # sm_margin
         # )
         dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
-            out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad
+            out_unpad,
+            (
+                q_unpad if unpad_q else q,
+                k_unpad if unpad_kv else k,
+                v_unpad if unpad_kv else v,
+            ),
+            g_unpad
         )
-        dq = dq_pad_fn(dq_unpad)
-        dk = dk_pad_fn(dk_unpad)
-        dv = dk_pad_fn(dv_unpad)
+        dq = dq_pad_fn(dq_unpad) if unpad_q else dq_unpad
+        dk = dk_pad_fn(dk_unpad) if unpad_kv else dk_unpad
+        dv = dk_pad_fn(dv_unpad) if unpad_kv else dv_unpad
         if key_unused_mask is not None:
             k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
             dk.masked_fill_(k_zero_masking, 0.0)
             dv.masked_fill_(k_zero_masking, 0.0)
         if query_unused_mask is not None:
             dq.masked_fill_(q_zero_masking, 0.0)
+        if not unpad_kv:
+            dk.masked_fill_(rearrange(~key_padding_mask, "b s -> b s 1 1"), 0.0)
+            dv.masked_fill_(rearrange(~key_padding_mask, "b s -> b s 1 1"), 0.0)
+        if not unpad_q:
+            dq.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
         # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
         # assert (softmax_d - do_o).abs().max().item() <= 1e-5
         # assert dq_accum.abs().max().item() == 0.0
-        g = output_pad_fn(g_unpad)
+        g = output_pad_fn(g_unpad) if unpad_q else g_unpad

         # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float()
         # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
@@ -762,6 +790,24 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
         print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
         print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
         print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+        if VERBOSE:
+            diff_dq = (dq - dq_ref).abs()
+            max_idx = diff_dq.argmax()
+            coords = torch.unravel_index(max_idx, diff_dq.shape)
+            print(f"dQ max diff: {diff_dq.max().item()}")
+            print(f" at coordinates {tuple(c.item() for c in coords)}: dQ={dq[coords].item()}, dQ_ref={dq_ref[coords].item()}")
+
+            diff_dk = (dk - dk_ref).abs()
+            max_idx = diff_dk.argmax()
+            coords = torch.unravel_index(max_idx, diff_dk.shape)
+            print(f"dK max diff: {diff_dk.max().item()}")
+            print(f" at coordinates {tuple(c.item() for c in coords)}: dK={dk[coords].item()}, dK_ref={dk_ref[coords].item()}")
+
+            diff_dv = (dv - dv_ref).abs()
+            max_idx = diff_dv.argmax()
+            coords = torch.unravel_index(max_idx, diff_dv.shape)
+            print(f"dV max diff: {diff_dv.max().item()}")
+            print(f" at coordinates {tuple(c.item() for c in coords)}: dV={dv[coords].item()}, dV_ref={dv_ref[coords].item()}")
         # breakpoint()
         dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (
             0 if softcap == 0 else 3e-4