Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions aiter/ops/triton/_triton_kernels/attention/mha_fused_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def _bwd_preprocess(
stride_o_h,
stride_o_m,
stride_o_k,
stride_do_b,
stride_do_h,
stride_do_m,
stride_do_k,
stride_delta_b,
stride_delta_h,
stride_delta_m,
Expand Down Expand Up @@ -70,14 +74,22 @@ def _bwd_preprocess(
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_k = tl.arange(0, BLOCK_D_MODEL_POW2)

# Offset O/DO by batch, head and q_start
offs = (
# O and DO may have different strides (e.g. BSHD vs SBHD memory layout),
# so address each with its own strides.
offs_o = (
bid * stride_o_b
+ hid * stride_o_h
+ q_start * stride_o_m
+ offs_m[:, None] * stride_o_m
+ offs_k[None, :] * stride_o_k
)
offs_do = (
bid * stride_do_b
+ hid * stride_do_h
+ q_start * stride_do_m
+ offs_m[:, None] * stride_do_m
+ offs_k[None, :] * stride_do_k
)

# create masks
mask_m = offs_m < seqlen_q
Expand All @@ -87,8 +99,8 @@ def _bwd_preprocess(
mask &= offs_k[None, :] < BLOCK_D_MODEL

# load [BLOCK_M, BLOCK_D_MODEL_POW2]
o = tl.load(o_ptr + offs, mask=mask, other=0.0)
do = tl.load(do_ptr + offs, mask=mask, other=0.0)
o = tl.load(o_ptr + offs_o, mask=mask, other=0.0)
do = tl.load(do_ptr + offs_do, mask=mask, other=0.0)

# compute and write-back to delta
if IS_FP8:
Expand Down
20 changes: 16 additions & 4 deletions aiter/ops/triton/_triton_kernels/attention/mha_onekernel_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def _bwd_preprocess(
stride_o_h,
stride_o_m,
stride_o_k,
stride_do_b,
stride_do_h,
stride_do_m,
stride_do_k,
stride_delta_b,
stride_delta_h,
stride_delta_m,
Expand Down Expand Up @@ -75,14 +79,22 @@ def _bwd_preprocess(
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_k = tl.arange(0, BLOCK_D_MODEL_POW2)

# Offset O/DO by batch, head and q_start
offs = (
# O and DO may have different strides (e.g. BSHD vs SBHD memory layout),
# so address each with its own strides.
offs_o = (
bid * stride_o_b
+ hid * stride_o_h
+ q_start * stride_o_m
+ offs_m[:, None] * stride_o_m
+ offs_k[None, :] * stride_o_k
)
offs_do = (
bid * stride_do_b
+ hid * stride_do_h
+ q_start * stride_do_m
+ offs_m[:, None] * stride_do_m
+ offs_k[None, :] * stride_do_k
)

# create masks
mask_m = offs_m < seqlen_q
Expand All @@ -92,8 +104,8 @@ def _bwd_preprocess(
mask &= offs_k[None, :] < BLOCK_D_MODEL

# load [BLOCK_M, BLOCK_D_MODEL_POW2]
o = tl.load(o_ptr + offs, mask=mask, other=0.0)
do = tl.load(do_ptr + offs, mask=mask, other=0.0)
o = tl.load(o_ptr + offs_o, mask=mask, other=0.0)
do = tl.load(do_ptr + offs_do, mask=mask, other=0.0)

# compute and write-back to delta
if IS_FP8:
Expand Down
1 change: 1 addition & 0 deletions aiter/ops/triton/attention/mha_fused_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def flash_attn_fused_backward(
do,
delta,
*o_strides,
*do_strides,
*delta_strides,
descale_strides[3],
cu_seqlens_q,
Expand Down
1 change: 1 addition & 0 deletions aiter/ops/triton/attention/mha_onekernel_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def flash_attn_onekernel_backward(
do,
delta,
*o_strides,
*do_strides,
*delta_strides,
descale_strides[3],
cu_seqlens_q,
Expand Down
72 changes: 72 additions & 0 deletions op_tests/triton_tests/attention/test_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,78 @@ def test_mha_backward(
assert_cosine_similarity(tri, ref)


@pytest.mark.parametrize("BATCH", [2, 4])
@pytest.mark.parametrize("SEQLEN_Q", [256, 512])
@pytest.mark.parametrize("SEQLEN_K", [256, 512])
@pytest.mark.parametrize("NUM_Q_HEADS", [32])
@pytest.mark.parametrize("NUM_K_HEADS", [8])
@pytest.mark.parametrize("HEAD_SZ", [128])
@pytest.mark.parametrize("CAUSAL", [False])
@pytest.mark.parametrize("FUSED", [False, True])
def test_mha_backward_sbhd_do(
BATCH: int,
SEQLEN_Q: int,
SEQLEN_K: int,
NUM_Q_HEADS: int,
NUM_K_HEADS: int,
HEAD_SZ: int,
CAUSAL: bool,
FUSED: bool,
dtype=torch.float16,
):
"""Verify backward correctness when dO has SBHD memory layout (strides differ from O).

Creates dO as a (seqlen, batch, nheads, headdim) tensor transposed to
(batch, seqlen, nheads, headdim), so its strides are different from the
contiguous BSHD output tensor. This exercises the independent stride
handling for dO in _bwd_preprocess.
"""
torch.cuda.empty_cache()
torch.manual_seed(42)
if FUSED and CAUSAL:
pytest.skip("FUSED+CAUSAL results in NaNs")

mha_set_use_fused_bwd_kernel(FUSED)

q = torch.randn(BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ, device="cuda", dtype=dtype)
k = torch.randn(BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ, device="cuda", dtype=dtype)
v = torch.randn(BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ, device="cuda", dtype=dtype)
q.requires_grad = True
k.requires_grad = True
v.requires_grad = True

# dO in SBHD memory layout: (seqlen, batch, nheads, headdim) viewed as BSHD
do_sbhd = torch.randn(
SEQLEN_Q, BATCH, NUM_Q_HEADS, HEAD_SZ, device="cuda", dtype=dtype
)
do = do_sbhd.transpose(0, 1) # shape is BSHD but strides are SBHD
assert not do.is_contiguous(), "dO should be non-contiguous (SBHD strides)"

# Reference: use contiguous dO for the reference computation
do_contig = do.contiguous()

# Triton forward + backward with SBHD-strided dO
with torch.enable_grad():
triton_out = flash_attn_func(q, k, v, causal=CAUSAL)
triton_dq, triton_dk, triton_dv = torch.autograd.grad(triton_out, (q, k, v), do)

# Reference forward + backward (with contiguous dO)
torch_out, torch_grads, fwd_tol, bwd_tols = _attention_ref_with_tol(
q,
k,
v,
do_contig,
causal=CAUSAL,
)
torch_dq, torch_dk, torch_dv = torch_grads

triton_vals = [triton_out, triton_dq, triton_dk, triton_dv]
ref_vals = [torch_out, torch_dq, torch_dk, torch_dv]
tols = [fwd_tol] + bwd_tols
for tri, ref, (atol, rtol) in zip(triton_vals, ref_vals, tols):
torch.testing.assert_close(tri, ref.to(tri.dtype), atol=atol, rtol=rtol)


@pytest.mark.parametrize("SEQLEN_Q", [512, 2048])
@pytest.mark.parametrize("SEQLEN_K", [512, 2048])
@pytest.mark.parametrize("NUM_Q_HEADS", [32, 64])
Expand Down
Loading