diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py index 9f8a1ab46a2..089676f5b0b 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py @@ -449,7 +449,7 @@ def _bwd_dq_inner( use_cuda_graph=True, ) @triton.jit -def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) +def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max(max_seqlen_q, max_seqlen_k), BLOCK_N1), batch) Q, K, V, sm_scale, DO, DQ, DK, DV, M, Delta, stride_qb, stride_qh, stride_qm, stride_qd, @@ -487,9 +487,9 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead DEBUG_TRITON_DETAIL: tl.constexpr, ): # program ids - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) + hkid = tl.program_id(0) + pid = tl.program_id(1) + bid = tl.program_id(2) if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 # figure out varlen start and end q_start = 0 @@ -843,9 +843,9 @@ def bwd_kernel_noncausal( DEBUG_TRITON_DETAIL: tl.constexpr, ): # program ids - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) + hkid = tl.program_id(0) + pid = tl.program_id(1) + bid = tl.program_id(2) if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 # figure out varlen start and end q_start = 0 @@ -1202,7 +1202,7 @@ def attention_prefill_backward_triton_split_oneKernel_impl( dropout_mask.stride() seqlen = max(max_seqlen_q_final, max_seqlen_k_final) - grid = lambda META: ((seqlen + META['BLOCK_N1'] - 1) // META['BLOCK_N1'], batch, nheads_k) + grid = lambda META: (nheads_k, (seqlen + META['BLOCK_N1'] - 1) // META['BLOCK_N1'], batch, ) if causal: if DEBUG_TRITON: print(f"bwd_kernel: grid = {grid}" ) # noqa: E701 bwd_kernel_causal[grid](