Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def _bwd_dq_inner(
use_cuda_graph=True,
)
@triton.jit
def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q)
def bwd_kernel_causal( # grid = (nheads_k, cdiv(max(max_seqlen_q, max_seqlen_k), BLOCK_N1), batch)
Q, K, V, sm_scale, DO, DQ, DK, DV,
M, Delta,
stride_qb, stride_qh, stride_qm, stride_qd,
Expand Down Expand Up @@ -487,9 +487,9 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead
DEBUG_TRITON_DETAIL: tl.constexpr,
):
# program ids
pid = tl.program_id(0)
bid = tl.program_id(1)
hkid = tl.program_id(2)
hkid = tl.program_id(0)
pid = tl.program_id(1)
bid = tl.program_id(2)
if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701
# figure out varlen start and end
q_start = 0
Expand Down Expand Up @@ -843,9 +843,9 @@ def bwd_kernel_noncausal(
DEBUG_TRITON_DETAIL: tl.constexpr,
):
# program ids
pid = tl.program_id(0)
bid = tl.program_id(1)
hkid = tl.program_id(2)
hkid = tl.program_id(0)
pid = tl.program_id(1)
bid = tl.program_id(2)
if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701
# figure out varlen start and end
q_start = 0
Expand Down Expand Up @@ -1202,7 +1202,7 @@ def attention_prefill_backward_triton_split_oneKernel_impl(
dropout_mask.stride()

seqlen = max(max_seqlen_q_final, max_seqlen_k_final)
grid = lambda META: ((seqlen + META['BLOCK_N1'] - 1) // META['BLOCK_N1'], batch, nheads_k)
grid = lambda META: (nheads_k, (seqlen + META['BLOCK_N1'] - 1) // META['BLOCK_N1'], batch, )
if causal:
if DEBUG_TRITON: print(f"bwd_kernel: grid = {grid}" ) # noqa: E701
bwd_kernel_causal[grid](
Expand Down
Loading