diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index ce0a1b6e5e9..8211e01965e 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -384,7 +384,8 @@ def __call__( self._check_type(*(t.element_type if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK))) # Assume all strides are divisible by 128 bits except the last stride - new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + # Skip cute.assume() for static strides: they are Python ints (e.g. stride=0 broadcast dims from expand()), not MLIR values + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) if not isinstance(s, int) else s for s in t.stride[:-1]), t.stride[-1]) mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)] self.varlen_q = (mCuSeqlensQ is not None) self._setup_attributes() diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index a94bdf3c85b..ede18638a73 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -351,8 +351,12 @@ def __call__( ) # Assume all strides are divisible by 128 bits except the last stride + # Skip cute.assume() for static strides: they are Python ints (e.g. stride=0 broadcast dims from expand()), not MLIR values new_stride = lambda t: ( - *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + *( + cute.assume(s, divby=128 // t.element_type.width) if not isinstance(s, int) else s + for s in t.stride[:-1] + ), t.stride[-1], ) mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [ diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index c341d26fbbf..3ba52ce4540 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -659,8 +659,14 @@ def __call__( self._setup_attributes() 
SharedStorage = self._get_shared_storage_cls() # Assume all strides are divisible by 128 bits except the last stride + # Skip cute.assume() for static strides: they are Python ints (e.g. stride=0 broadcast dims from expand()), not MLIR values new_stride = lambda t: ( - *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + *( + cute.assume(s, divby=128 // t.element_type.width) + if not isinstance(s, int) + else s + for s in t.stride[:-1] + ), t.stride[-1], ) mQ, mK, mV, mO = [ @@ -1296,8 +1302,14 @@ def __call__( ) # Assume all strides are divisible by 128 bits except the last stride + # Skip cute.assume() for static strides: they are Python ints (e.g. stride=0 broadcast dims from expand()), not MLIR values new_stride = lambda t: ( - *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + *( + cute.assume(s, divby=128 // t.element_type.width) + if not isinstance(s, int) + else s + for s in t.stride[:-1] + ), t.stride[-1], ) diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index a4b5bf27107..f830fcb0afb 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -989,5 +989,119 @@ def run_with_block_mask_nheads(block_mask_nheads: int) -> tuple[torch.Tensor, to assert err_no_broadcast_dq < 0.1, f"No-broadcast dQ error too large: {err_no_broadcast_dq:.2e}" +def test_gqa_expand_stride_zero_bug(): + """Test that GQA with expand()-created K/V tensors works correctly. + + This is a regression test for bugs with expand()-created tensors: + + Forward bug: cute.assume() fails when tensor strides are Python int 0 + (from expand()) instead of MLIR values. + Error: AttributeError: 'int' object has no attribute 'type' + + Backward bug: mark_layout_dynamic fails with expanded tensors. + Error: RuntimeError: Expected strides[leading_dim] == 1, but got N. + + Trigger: expand() + transpose() creates stride=0 dimensions (GQA pattern). 
+ """ + torch.manual_seed(42) + + batch_size = 1 + seqlen = 2048 + headdim = 128 + n_heads = 4 + n_kv_heads = 1 + dtype = torch.bfloat16 + device = "cuda" + + q = torch.randn(batch_size, seqlen, n_heads, headdim, device=device, dtype=dtype) + k_orig = torch.randn(batch_size, seqlen, n_kv_heads, headdim, device=device, dtype=dtype) + v_orig = torch.randn(batch_size, seqlen, n_kv_heads, headdim, device=device, dtype=dtype) + + k = k_orig.expand(batch_size, seqlen, n_heads, headdim) + v = v_orig.expand(batch_size, seqlen, n_heads, headdim) + + assert k.stride()[2] == 0, "K should have stride=0 in head dim from expand()" + assert v.stride()[2] == 0, "V should have stride=0 in head dim from expand()" + + out = torch.empty_like(q) + lse = torch.empty(batch_size, n_heads, seqlen, device=device, dtype=torch.float32) + softmax_scale = 1.0 / math.sqrt(headdim) + + out_tuple = _flash_attn_fwd( + q=q, k=k, v=v, out=out, lse=lse, + softmax_scale=softmax_scale, + causal=True, + m_block_size=128, n_block_size=128, + return_lse=True, + ) + out_fwd, lse_fwd = out_tuple[0], out_tuple[1] + + assert not torch.isnan(out_fwd).any(), "Forward output contains NaN" + assert torch.isfinite(out_fwd).all(), "Forward output contains non-finite values" + + tensors_for_ref = {"q": q, "k": k, "v": v} + tensors_fp32 = {"q": q.float(), "k": k.float(), "v": v.float()} + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + out_ref = compute_reference_flex_attn(tensors_for_ref, causal_mask) + out_ref_fp32 = compute_reference_flex_attn(tensors_fp32, causal_mask) + + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + pt_error = (out_ref - out_ref_fp32).abs().max().item() + cute_error = (out_fwd - out_ref_fp32).abs().max().item() + + print(f"\nGQA expand stride=0 test:") + print(f" Forward: kernel err={cute_error:.2e}, ref err={pt_error:.2e}, atol={fwd_atol:.2e}") + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"Forward error {cute_error:.2e} 
exceeds {rtol}x ref error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + grad_out = torch.randn_like(out_fwd) + dq, dk, dv = _flash_attn_bwd( + q=q, k=k, v=v, out=out_fwd, dout=grad_out, lse=lse_fwd, + softmax_scale=softmax_scale, + causal=True, + m_block_size=128, n_block_size=128, + ) + + assert not torch.isnan(dq).any(), "dQ contains NaN" + assert not torch.isnan(dk).any(), "dK contains NaN" + assert not torch.isnan(dv).any(), "dV contains NaN" + + flex_block_mask = create_block_mask( + causal_mask, batch_size, n_heads, seqlen, seqlen, + device=device, BLOCK_SIZE=(128, 128), + ) + _, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd(q, k, v, flex_block_mask, grad_out, dtype=torch.float32) + + bwd_rtol = 2 + bwd_atol_floor = 1e-5 + + dq_atol = max(bwd_atol_floor, 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item()) + dk_atol = max(bwd_atol_floor, 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item()) + dv_atol = max(bwd_atol_floor, 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item()) + + _, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, flex_block_mask, grad_out) + + pt_dq_err = (dq_pt - dq_ref.to(dtype)).abs().max().item() + pt_dk_err = (dk_pt - dk_ref.to(dtype)).abs().max().item() + pt_dv_err = (dv_pt - dv_ref.to(dtype)).abs().max().item() + + cute_dq_err = (dq - dq_ref.to(dtype)).abs().max().item() + cute_dk_err = (dk - dk_ref.to(dtype)).abs().max().item() + cute_dv_err = (dv - dv_ref.to(dtype)).abs().max().item() + + print(f" Backward dQ: kernel err={cute_dq_err:.2e}, ref err={pt_dq_err:.2e}, atol={dq_atol:.2e}") + print(f" Backward dK: kernel err={cute_dk_err:.2e}, ref err={pt_dk_err:.2e}, atol={dk_atol:.2e}") + print(f" Backward dV: kernel err={cute_dv_err:.2e}, ref err={pt_dv_err:.2e}, atol={dv_atol:.2e}") + + assert cute_dq_err <= bwd_rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}" + assert cute_dk_err <= bwd_rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}" + assert cute_dv_err <= bwd_rtol * pt_dv_err + dv_atol, 
f"dV error too large: {cute_dv_err:.2e}" + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"])