diff --git a/aiter/ops/triton/_triton_kernels/attention/mha_fused_bwd.py b/aiter/ops/triton/_triton_kernels/attention/mha_fused_bwd.py
index 79265cc87f..ff5b96b77c 100644
--- a/aiter/ops/triton/_triton_kernels/attention/mha_fused_bwd.py
+++ b/aiter/ops/triton/_triton_kernels/attention/mha_fused_bwd.py
@@ -38,6 +38,10 @@ def _bwd_preprocess(
     stride_o_h,
     stride_o_m,
     stride_o_k,
+    stride_do_b,
+    stride_do_h,
+    stride_do_m,
+    stride_do_k,
     stride_delta_b,
     stride_delta_h,
     stride_delta_m,
@@ -70,14 +74,22 @@ def _bwd_preprocess(
     offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offs_k = tl.arange(0, BLOCK_D_MODEL_POW2)
 
-    # Offset O/DO by batch, head and q_start
-    offs = (
+    # O and DO may have different strides (e.g. BSHD vs SBHD memory layout),
+    # so address each with its own strides.
+    offs_o = (
         bid * stride_o_b
         + hid * stride_o_h
         + q_start * stride_o_m
         + offs_m[:, None] * stride_o_m
         + offs_k[None, :] * stride_o_k
     )
+    offs_do = (
+        bid * stride_do_b
+        + hid * stride_do_h
+        + q_start * stride_do_m
+        + offs_m[:, None] * stride_do_m
+        + offs_k[None, :] * stride_do_k
+    )
 
     # create masks
     mask_m = offs_m < seqlen_q
@@ -87,8 +99,8 @@ def _bwd_preprocess(
         mask &= offs_k[None, :] < BLOCK_D_MODEL
 
     # load [BLOCK_M, BLOCK_D_MODEL_POW2]
-    o = tl.load(o_ptr + offs, mask=mask, other=0.0)
-    do = tl.load(do_ptr + offs, mask=mask, other=0.0)
+    o = tl.load(o_ptr + offs_o, mask=mask, other=0.0)
+    do = tl.load(do_ptr + offs_do, mask=mask, other=0.0)
 
     # compute and write-back to delta
     if IS_FP8:
diff --git a/aiter/ops/triton/_triton_kernels/attention/mha_onekernel_bwd.py b/aiter/ops/triton/_triton_kernels/attention/mha_onekernel_bwd.py
index 7563920c38..1ff5a2ae57 100644
--- a/aiter/ops/triton/_triton_kernels/attention/mha_onekernel_bwd.py
+++ b/aiter/ops/triton/_triton_kernels/attention/mha_onekernel_bwd.py
@@ -43,6 +43,10 @@ def _bwd_preprocess(
     stride_o_h,
     stride_o_m,
     stride_o_k,
+    stride_do_b,
+    stride_do_h,
+    stride_do_m,
+    stride_do_k,
     stride_delta_b,
     stride_delta_h,
     stride_delta_m,
@@ -75,14 +79,22 @@ def _bwd_preprocess(
     offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offs_k = tl.arange(0, BLOCK_D_MODEL_POW2)
 
-    # Offset O/DO by batch, head and q_start
-    offs = (
+    # O and DO may have different strides (e.g. BSHD vs SBHD memory layout),
+    # so address each with its own strides.
+    offs_o = (
         bid * stride_o_b
         + hid * stride_o_h
         + q_start * stride_o_m
         + offs_m[:, None] * stride_o_m
         + offs_k[None, :] * stride_o_k
     )
+    offs_do = (
+        bid * stride_do_b
+        + hid * stride_do_h
+        + q_start * stride_do_m
+        + offs_m[:, None] * stride_do_m
+        + offs_k[None, :] * stride_do_k
+    )
 
     # create masks
     mask_m = offs_m < seqlen_q
@@ -92,8 +104,8 @@ def _bwd_preprocess(
         mask &= offs_k[None, :] < BLOCK_D_MODEL
 
     # load [BLOCK_M, BLOCK_D_MODEL_POW2]
-    o = tl.load(o_ptr + offs, mask=mask, other=0.0)
-    do = tl.load(do_ptr + offs, mask=mask, other=0.0)
+    o = tl.load(o_ptr + offs_o, mask=mask, other=0.0)
+    do = tl.load(do_ptr + offs_do, mask=mask, other=0.0)
 
     # compute and write-back to delta
     if IS_FP8:
diff --git a/aiter/ops/triton/attention/mha_fused_bwd.py b/aiter/ops/triton/attention/mha_fused_bwd.py
index 8b9f3f38f7..634d8a0a2c 100644
--- a/aiter/ops/triton/attention/mha_fused_bwd.py
+++ b/aiter/ops/triton/attention/mha_fused_bwd.py
@@ -181,6 +181,7 @@ def flash_attn_fused_backward(
         do,
         delta,
         *o_strides,
+        *do_strides,
         *delta_strides,
         descale_strides[3],
         cu_seqlens_q,
diff --git a/aiter/ops/triton/attention/mha_onekernel_bwd.py b/aiter/ops/triton/attention/mha_onekernel_bwd.py
index 4ef9c24654..1b2a3f7800 100644
--- a/aiter/ops/triton/attention/mha_onekernel_bwd.py
+++ b/aiter/ops/triton/attention/mha_onekernel_bwd.py
@@ -221,6 +221,7 @@ def flash_attn_onekernel_backward(
         do,
         delta,
         *o_strides,
+        *do_strides,
         *delta_strides,
         descale_strides[3],
         cu_seqlens_q,
diff --git a/op_tests/triton_tests/attention/test_mha.py b/op_tests/triton_tests/attention/test_mha.py
index cc25d69559..96495c168e 100644
--- a/op_tests/triton_tests/attention/test_mha.py
+++ b/op_tests/triton_tests/attention/test_mha.py
@@ -685,6 +685,78 @@ def test_mha_backward(
     assert_cosine_similarity(tri, ref)
 
 
+@pytest.mark.parametrize("BATCH", [2, 4])
+@pytest.mark.parametrize("SEQLEN_Q", [256, 512])
+@pytest.mark.parametrize("SEQLEN_K", [256, 512])
+@pytest.mark.parametrize("NUM_Q_HEADS", [32])
+@pytest.mark.parametrize("NUM_K_HEADS", [8])
+@pytest.mark.parametrize("HEAD_SZ", [128])
+@pytest.mark.parametrize("CAUSAL", [False])
+@pytest.mark.parametrize("FUSED", [False, True])
+def test_mha_backward_sbhd_do(
+    BATCH: int,
+    SEQLEN_Q: int,
+    SEQLEN_K: int,
+    NUM_Q_HEADS: int,
+    NUM_K_HEADS: int,
+    HEAD_SZ: int,
+    CAUSAL: bool,
+    FUSED: bool,
+    dtype=torch.float16,
+):
+    """Verify backward correctness when dO has SBHD memory layout (strides differ from O).
+
+    Creates dO as a (seqlen, batch, nheads, headdim) tensor transposed to
+    (batch, seqlen, nheads, headdim), so its strides are different from the
+    contiguous BSHD output tensor. This exercises the independent stride
+    handling for dO in _bwd_preprocess.
+    """
+    torch.cuda.empty_cache()
+    torch.manual_seed(42)
+    if FUSED and CAUSAL:
+        pytest.skip("FUSED+CAUSAL results in NaNs")
+
+    mha_set_use_fused_bwd_kernel(FUSED)
+
+    q = torch.randn(BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ, device="cuda", dtype=dtype)
+    k = torch.randn(BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ, device="cuda", dtype=dtype)
+    v = torch.randn(BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ, device="cuda", dtype=dtype)
+    q.requires_grad = True
+    k.requires_grad = True
+    v.requires_grad = True
+
+    # dO in SBHD memory layout: (seqlen, batch, nheads, headdim) viewed as BSHD
+    do_sbhd = torch.randn(
+        SEQLEN_Q, BATCH, NUM_Q_HEADS, HEAD_SZ, device="cuda", dtype=dtype
+    )
+    do = do_sbhd.transpose(0, 1)  # shape is BSHD but strides are SBHD
+    assert not do.is_contiguous(), "dO should be non-contiguous (SBHD strides)"
+
+    # Reference: use contiguous dO for the reference computation
+    do_contig = do.contiguous()
+
+    # Triton forward + backward with SBHD-strided dO
+    with torch.enable_grad():
+        triton_out = flash_attn_func(q, k, v, causal=CAUSAL)
+        triton_dq, triton_dk, triton_dv = torch.autograd.grad(triton_out, (q, k, v), do)
+
+    # Reference forward + backward (with contiguous dO)
+    torch_out, torch_grads, fwd_tol, bwd_tols = _attention_ref_with_tol(
+        q,
+        k,
+        v,
+        do_contig,
+        causal=CAUSAL,
+    )
+    torch_dq, torch_dk, torch_dv = torch_grads
+
+    triton_vals = [triton_out, triton_dq, triton_dk, triton_dv]
+    ref_vals = [torch_out, torch_dq, torch_dk, torch_dv]
+    tols = [fwd_tol] + bwd_tols
+    for tri, ref, (atol, rtol) in zip(triton_vals, ref_vals, tols):
+        torch.testing.assert_close(tri, ref.to(tri.dtype), atol=atol, rtol=rtol)
+
+
 @pytest.mark.parametrize("SEQLEN_Q", [512, 2048])
 @pytest.mark.parametrize("SEQLEN_K", [512, 2048])
 @pytest.mark.parametrize("NUM_Q_HEADS", [32, 64])