diff --git a/.gitignore b/.gitignore
index 9211dcae38b..21d5dda524a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,3 +50,4 @@ training/data
 # ck modules
 csrc/composable_kernel
 csrc/cutlass
+.analysis
\ No newline at end of file
diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py
index 8bdcfd10d6a..5b2f8858d11 100644
--- a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py
+++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py
@@ -99,8 +99,7 @@ def get_autotune_configs():
 # Here is the I/O shape:
 # Out: (batch, nhead_q, max_seqlens_q, headDim)
 # DO: (batch, nhead_q, max_seqlens_q, headDim)
-# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at
-# fwd_prefill.py line 607
+# Delta: (batch, nheads_q, max_seqlens_q)
 @triton.autotune(
     configs=preprocess_autotune_configs,
     key=preprocess_autotune_keys,
@@ -108,9 +107,11 @@ def get_autotune_configs():
 )
 @triton.jit
 def _bwd_preprocess(
-    O, DO,  # noqa: E741
+    O,  # noqa: E741
+    DO,
     Delta,
     stride_ob, stride_oh, stride_om, stride_od,
+    stride_dob, stride_doh, stride_dom, stride_dod,
     stride_delta_b, stride_delta_h, stride_delta_m,
     stride_descale_do_z,
     cu_seqlens_q, max_seqlen_q,
@@ -125,8 +126,6 @@ def _bwd_preprocess(
     bid = tl.program_id(1)
     hid = tl.program_id(2)
     # Handle varlen
-    q_start = 0
-    seqlen_q = max_seqlen_q
     if IS_VARLEN:
         q_start = tl.load(cu_seqlens_q + bid)
         q_end = tl.load(cu_seqlens_q + bid + 1)
@@ -138,32 +137,41 @@ def _bwd_preprocess(
     # Compute offsets
     offs_m = pid_m * PRE_BLOCK + tl.arange(0, PRE_BLOCK)
     offs_d = tl.arange(0, HEAD_DIM)
-    # Offset O/DO by batch, head and q_start
-    O += bid * stride_ob + hid * stride_oh + q_start * stride_om  # noqa: E741
-    DO += bid * stride_ob + hid * stride_oh + q_start * stride_om
+    # pointer offsets for O & DO
+    off_o = (bid * stride_ob
+             + hid * stride_oh
+             + q_start * stride_om
+             + offs_m[:, None] * stride_om
+             + offs_d[None, :] * stride_od)
+    off_do = (bid * stride_dob
+              + hid * stride_doh
+              + q_start * stride_dom
+              + offs_m[:, None] * stride_dom
+              + offs_d[None, :] * stride_dod)
+    # create masks
     mask_m = offs_m < seqlen_q
     mask_md = mask_m[:, None]
     PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM)
     if PADDED_HEAD:
         mask_md &= offs_d[None, :] < ACTUAL_HEAD_DIM
-    # compute pointers
-    offs_do = offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
-    out_ptrs = O + offs_do
-    do_ptrs = DO + offs_do
     # load
-    o = tl.load(out_ptrs, mask=mask_md, other=0.0)
-    do = tl.load(do_ptrs, mask=mask_md, other=0.0)
+    o = tl.load(O + off_o, mask=mask_md, other=0.0)
+    do = tl.load(DO + off_do, mask=mask_md, other=0.0)
     # compute and write-back to delta
     if IS_FP8:
-        descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hid)
+        off_descale_do = bid * stride_descale_do_z + hid
+        descale_do = tl.load(Descale_do + off_descale_do)
         # NOTE: do is in the fp8 range and o is not in fp8
         delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1)
     else:
         delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1)
-    delta_offset = Delta + bid * stride_delta_b + hid * stride_delta_h + q_start * stride_delta_m
-    tl.store(delta_offset + offs_m * stride_delta_m, delta, mask=mask_m)
+    off_delta = (bid * stride_delta_b
+                 + hid * stride_delta_h
+                 + q_start * stride_delta_m
+                 + offs_m * stride_delta_m)
+    tl.store(Delta + off_delta, delta, mask=mask_m)


 # The main inner-loop logic for computing dK and dV.
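For orientation, the quantity _bwd_preprocess writes out is the row-wise dot product delta[b, h, m] = sum_d O[b, m, h, d] * dO[b, m, h, d]. Below is a minimal PyTorch sketch of the non-fp8 path for the bshd layout; the function name is illustrative and not part of this diff.

import torch

def bwd_preprocess_reference(o: torch.Tensor, do: torch.Tensor) -> torch.Tensor:
    # o, do: (batch, seqlen, nheads, head_dim); accumulate in fp32 like the kernel
    delta = (o.float() * do.float()).sum(dim=-1)  # (batch, seqlen, nheads)
    # Delta is laid out as (batch, nheads, seqlen), matching the kernel's strides
    return delta.permute(0, 2, 1).contiguous()

The new stride_dob/stride_doh/stride_dom/stride_dod arguments matter because dO is not guaranteed to share O's memory layout: the old code indexed DO with O's strides, which silently assumed identical layouts.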
@@ -1063,8 +1071,9 @@ def is_contiguous(x, name):
         print(f"{name} is not contiguous")
     return x.contiguous()

-
-OLD_LSE = os.environ.get('OLD_LSE', '0').lower() in ('1', 'true', 'yes')
+OLD_LSE: bool = False
+DEBUG_TRITON: bool = False
+DEBUG_TRITON_DETAIL: bool = False

 def attention_prefill_backward_triton_split_fused_no_atomics_impl(
     do: torch.Tensor,
@@ -1098,25 +1107,118 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
     descale_dk: Optional[torch.Tensor],
     descale_dv: Optional[torch.Tensor],
 ):
-    # debug
-    DEBUG_TRITON: bool = False
-    DEBUG_TRITON_DETAIL: bool = False
-
-    # do = is_contiguous(do, "do")
-    # q = is_contiguous(q, "q")
-    # k = is_contiguous(k, "k")
-    # v = is_contiguous(v, "v")
-    # o = is_contiguous(o, "o")
-    # softmax_lse = is_contiguous(softmax_lse, "softmax_lse")
-    # dq = is_contiguous(dq, "dq")
-    # dk = is_contiguous(dk, "dk")
-    # dv = is_contiguous(dv, "dv")
-
+    # get params, strides and shape
+    IS_VARLEN = layout == "thd"
+    use_dropout = (dropout_p > 0.0)
+
+    # common assertions
+    assert 0.0 <= dropout_p <= 1.0, f"dropout_p must be between 0 and 1, got {dropout_p}"
+    assert q.device == k.device == v.device == o.device == do.device == softmax_lse.device, \
+        f"All tensors must be on the same device. Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}, do={do.device}, softmax_lse={softmax_lse.device}"
+    assert q.dtype == k.dtype == v.dtype == do.dtype, "q, k, v, do must have the same dtype"
+    current_device = torch.cuda.current_device()
+    assert q.is_cuda and q.device.index == current_device, f"Device mismatch: Kernel will launch on cuda:{current_device}, but tensors are on {q.device}"
+
+    # get shapes and strides
+    if IS_VARLEN:
+        # shape
+        total_seqlen_q, nheads_q, head_size_q = q.shape
+        total_seqlen_k, nheads_k, head_size_k = k.shape
+        total_seqlen_v, nheads_v, head_size_v = v.shape
+        nheads_lse, total_seqlen_lse = softmax_lse.shape
+
+        # assert shapes
+        assert total_seqlen_lse == total_seqlen_q, f"softmax_lse seqlen {total_seqlen_lse} != q seqlen {total_seqlen_q}"
+        assert cu_seqlens_q is not None, "cu_seqlens_q must be provided for varlen layout"
+        assert cu_seqlens_k is not None, "cu_seqlens_k must be provided for varlen layout"
+        assert max_seqlen_q is not None, "max_seqlen_q must be provided for varlen layout"
+        assert max_seqlen_k is not None, "max_seqlen_k must be provided for varlen layout"
+
+        # assert head dimensions
+        assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}"
+        assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}"
+        assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA"
+        assert nheads_lse == nheads_q, f"softmax_lse heads {nheads_lse} != q heads {nheads_q}"
+
+        # assert output shapes
+        assert o.shape == (total_seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_q)}"
+        assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}"
+        assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}"
+        assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}"
+        assert dv.shape == v.shape, f"dv shape {dv.shape} != v shape {v.shape}"
+
+        # assert cu_seqlens
+        assert cu_seqlens_q.dtype == torch.int32, f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}"
+        assert cu_seqlens_k.dtype == torch.int32, f"cu_seqlens_k must be int32, got {cu_seqlens_k.dtype}"
+        assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0"
+        assert cu_seqlens_k[0] == 0, "cu_seqlens_k must start with 0"
+        assert cu_seqlens_q[-1] == total_seqlen_q, f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} != total_seqlen_q {total_seqlen_q}"
+        assert cu_seqlens_k[-1] == total_seqlen_k, f"cu_seqlens_k[-1] {cu_seqlens_k[-1]} != total_seqlen_k {total_seqlen_k}"
+
+        # set vars
+        batch = len(cu_seqlens_q) - 1
+        head_size = head_size_q
+
+        # strides
+        stride_qb, stride_qm, stride_qh, stride_qd = 0, q.stride(0), q.stride(1), q.stride(2)
+        stride_kb, stride_kn, stride_kh, stride_kd = 0, k.stride(0), k.stride(1), k.stride(2)
+        stride_vb, stride_vn, stride_vh, stride_vd = 0, v.stride(0), v.stride(1), v.stride(2)
+        stride_ob, stride_om, stride_oh, stride_od = 0, o.stride(0), o.stride(1), o.stride(2)
+        stride_dqb, stride_dqm, stride_dqh, stride_dqd = 0, dq.stride(0), dq.stride(1), dq.stride(2)
+        stride_dkb, stride_dkn, stride_dkh, stride_dkd = 0, dk.stride(0), dk.stride(1), dk.stride(2)
+        stride_dvb, stride_dvn, stride_dvh, stride_dvd = 0, dv.stride(0), dv.stride(1), dv.stride(2)
+        stride_dob, stride_dom, stride_doh, stride_dod = 0, do.stride(0), do.stride(1), do.stride(2)
+        stride_lse_b, stride_lse_h, stride_lse_m = (0, softmax_lse.stride(0), softmax_lse.stride(1))
+    else:
+        # shapes
+        batch_q, seqlen_q, nheads_q, head_size_q = q.shape
+        batch_k, seqlen_k, nheads_k, head_size_k = k.shape
+        batch_v, seqlen_v, nheads_v, head_size_v = v.shape
+        batch_lse, nheads_lse, seqlen_lse = softmax_lse.shape
+
+        # assert batch dimensions
+        assert batch_q == batch_k == batch_v, f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}"
+
+        # assert head dimensions
+        assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}"
+        assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}"
+        assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA"
+
+        # assert sequence lengths
+        assert seqlen_k == seqlen_v, f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}"
+
+        # assert output shapes
+        assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected"
+        assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}"
+        assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}"
+        assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}"
+        assert dv.shape == v.shape, f"dv shape {dv.shape} != v shape {v.shape}"
+
+        # assert softmax_lse shape
+        assert softmax_lse.shape == (batch_q, nheads_q, seqlen_q), f"softmax_lse shape {softmax_lse.shape} != expected"
+
+        # set vars
+        batch = batch_q
+        head_size = head_size_q
+        max_seqlen_q = seqlen_q
+        max_seqlen_k = seqlen_k
+
+        # strides
+        stride_qb, stride_qm, stride_qh, stride_qd = q.stride()
+        stride_kb, stride_kn, stride_kh, stride_kd = k.stride()
+        stride_vb, stride_vn, stride_vh, stride_vd = v.stride()
+        stride_ob, stride_om, stride_oh, stride_od = o.stride()
+        stride_dqb, stride_dqm, stride_dqh, stride_dqd = dq.stride()
+        stride_dkb, stride_dkn, stride_dkh, stride_dkd = dk.stride()
+        stride_dvb, stride_dvn, stride_dvh, stride_dvd = dv.stride()
+        stride_dob, stride_dom, stride_doh, stride_dod = do.stride()
+        stride_lse_b, stride_lse_h, stride_lse_m = softmax_lse.stride()
+
+    # fp8 setup - moved after all assertions
     IS_FP8 = is_fp8(q)
     if IS_FP8:
         FP8_MAX = torch.finfo(q.dtype).max
-        # assert that the main inputs are fp8
-        assert is_fp8(do) and is_fp8(q) and is_fp8(k) and is_fp8(v), f"Non fp8 type found: do.dtype={do.dtype}, q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}. All tensors must be fp8."
+        # we already asserted that do, q, k, v all have the same dtype, so no need to check each one
         if is_fp8(o):
             FP8_OUTPUT = True
             assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o."
@@ -1136,45 +1238,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
         FP8_OUTPUT = False
         stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = stride_descale_do_z = None
-
-    # get params, strides and shape
-    IS_VARLEN = layout == "thd"
-    use_dropout = (dropout_p > 0.0)
-
-    # get shapes and strides
-    if IS_VARLEN:
-        # shape
-        _, nheads_q, head_size = q.shape
-        _, nheads_k, _ = k.shape
-        batch = len(cu_seqlens_q) - 1
-        max_seqlen_q_final = max_seqlen_q
-        max_seqlen_k_final = max_seqlen_k
-
-        # strides
-        stride_qb, stride_qh, stride_qm, stride_qd = 0, q.stride(1), q.stride(0), q.stride(2)
-        stride_kb, stride_kh, stride_kn, stride_kd = 0, k.stride(1), k.stride(0), k.stride(2)
-        stride_vb, stride_vh, stride_vn, stride_vd = 0, v.stride(1), v.stride(0), v.stride(2)
-        stride_ob, stride_oh, stride_om, stride_od = 0, o.stride(1), o.stride(0), o.stride(2)
-        stride_dqb, stride_dqh, stride_dqm, stride_dqd = 0, dq.stride(1), dq.stride(0), dq.stride(2)
-        stride_dkb, stride_dkh, stride_dkn, stride_dkd = 0, dk.stride(1), dk.stride(0), dk.stride(2)
-        stride_dvb, stride_dvh, stride_dvn, stride_dvd = 0, dv.stride(1), dv.stride(0), dv.stride(2)
-        stride_dob, stride_doh, stride_dom, stride_dod = 0, do.stride(1), do.stride(0), do.stride(2)
-        stride_lse_b, stride_lse_h, stride_lse_m = (0, softmax_lse.stride(0), softmax_lse.stride(1))
-    else:
-        # shapes
-        batch, max_seqlen_q_final, nheads_q, head_size = q.shape
-        _, max_seqlen_k_final, nheads_k, _ = k.shape
-
-        # strides
-        stride_qb, stride_qh, stride_qm, stride_qd = q.stride(0), q.stride(2), q.stride(1), q.stride(3)
-        stride_kb, stride_kh, stride_kn, stride_kd = k.stride(0), k.stride(2), k.stride(1), k.stride(3)
-        stride_vb, stride_vh, stride_vn, stride_vd = v.stride(0), v.stride(2), v.stride(1), v.stride(3)
-        stride_ob, stride_oh, stride_om, stride_od = o.stride(0), o.stride(2), o.stride(1), o.stride(3)
-        stride_dqb, stride_dqh, stride_dqm, stride_dqd = dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3)
-        stride_dkb, stride_dkh, stride_dkn, stride_dkd = dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3)
-        stride_dvb, stride_dvh, stride_dvn, stride_dvd = dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3)
-        stride_dob, stride_doh, stride_dom, stride_dod = do.stride(0), do.stride(2), do.stride(1), do.stride(3)
-        stride_lse_b, stride_lse_h, stride_lse_m = softmax_lse.stride()
+
+    # alibi setup
     use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0))
     # get closest power of 2 over or equal to 32.
@@ -1193,28 +1257,28 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
     else:
         if IS_VARLEN:
             # interface expects the varlen sequence dims to be rounded like this. Not sure why.
-            batch_size = cu_seqlens_q.numel() - 1
             total_q, num_heads, _ = q.shape
-            total_q_rounded = total_q + 128 * batch_size
+            total_q_rounded = total_q + 128 * batch
             delta_padded = torch.zeros((nheads_q, total_q_rounded), device=q.device, dtype=torch.float32)
             delta = delta_padded[:, :total_q]
             stride_delta_b, stride_delta_h, stride_delta_m = 0, delta.stride(0), delta.stride(1)
         else:
             # the interface expects the sequence dimension to be rounded to 128
-            max_seqlen_q_rounded = round_multiple(max_seqlen_q_final, 128)
+            max_seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
             delta_padded = torch.zeros((batch, nheads_q, max_seqlen_q_rounded),
-                                       device=softmax_lse.device, dtype=torch.float32)
-            delta = delta_padded[:, :, :max_seqlen_q_final]
+                                       device=q.device, dtype=torch.float32)
+            delta = delta_padded[:, :, :max_seqlen_q]
             stride_delta_b, stride_delta_h, stride_delta_m = delta.stride()

-    pre_grid = lambda META: (triton.cdiv(max_seqlen_q_final, META['PRE_BLOCK']), batch, nheads_q)
+    pre_grid = lambda META: (triton.cdiv(max_seqlen_q, META['PRE_BLOCK']), batch, nheads_q)
     _bwd_preprocess[pre_grid](
         o, do,
         delta,
         stride_ob, stride_oh, stride_om, stride_od,
+        stride_dob, stride_doh, stride_dom, stride_dod,
         stride_delta_b, stride_delta_h, stride_delta_m,
         stride_descale_do_z,
-        cu_seqlens_q, max_seqlen_q_final,
+        cu_seqlens_q, max_seqlen_q,
         descale_do,
         HEAD_DIM=HEAD_DIM,
         ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM,
@@ -1232,7 +1296,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
             (0, 0, 0, 0)
     if use_dropout:
         dropout_mask = torch.zeros(
-            (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final),
+            (batch, nheads_q, max_seqlen_q, max_seqlen_k),
             device=q.device,
             dtype=torch.float32
         )
@@ -1241,7 +1305,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
         if not IS_VARLEN:
             dropout_mask = create_dropout_mask(
                 dropout_p,
-                (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final),
+                (batch, nheads_q, max_seqlen_q, max_seqlen_k),
                 seed=philox_seed
             )
         else:
@@ -1252,7 +1316,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
         stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \
             dropout_mask.stride()

-    seqlen = max(max_seqlen_q_final, max_seqlen_k_final)
+    seqlen = max(max_seqlen_q, max_seqlen_k)
     grid = lambda META: (nheads_k, (seqlen + META['BLOCK_N1'] - 1) // META['BLOCK_N1'], batch,)
     if causal:
         if DEBUG_TRITON: print(f"bwd_kernel: grid = {grid}")  # noqa: E701
@@ -1273,7 +1337,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
             stride_az, stride_ah,
             nheads_q, nheads_k,
             cu_seqlens_q, cu_seqlens_k,
-            max_seqlen_q_final, max_seqlen_k_final,
+            max_seqlen_q, max_seqlen_k,
             dropout_mask, dropout_p, philox_seed, philox_offset,
             alibi_slopes,
             descale_q, descale_k, descale_v, descale_do,
@@ -1307,7 +1371,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl(
             stride_az, stride_ah,
             nheads_q, nheads_k,
             cu_seqlens_q, cu_seqlens_k,
-            max_seqlen_q_final, max_seqlen_k_final,
+            max_seqlen_q, max_seqlen_k,
             dropout_mask, dropout_p, philox_seed, philox_offset,
             alibi_slopes,
             descale_q, descale_k, descale_v, descale_do,
diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
index 59fe8bfaf4e..b646b486ce1 100644
--- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py
+++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -526,19 +526,14 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx,
              USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr,
              WINDOW_SIZE_RIGHT: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
              PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr,
              RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr,
-             IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr, FLIP_GRID: tl.constexpr):
+             IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr):
     # set params
     ACCUMULATOR_TYPE = tl.float32
     # compute offsets
-    if FLIP_GRID:
-        off_z = tl.program_id(0)
-        off_h_q = tl.program_id(1)
-        start_m = tl.program_id(2)
-    else:
-        start_m = tl.program_id(0)
-        off_h_q = tl.program_id(1)
-        off_z = tl.program_id(2)
+    off_z = tl.program_id(0)
+    off_h_q = tl.program_id(1)
+    start_m = tl.program_id(2)
     # If MQA / GQA, set the K and V head offsets appropriately.
     GROUP_SIZE: tl.constexpr = HQ // HK
     if GROUP_SIZE != 1:
@@ -899,19 +894,116 @@ def attention_prefill_forward_triton_impl(
     descale_v: Optional[torch.Tensor],
     descale_o: Optional[torch.Tensor],
 ):
+    # get params, strides and shape
+    IS_VARLEN = layout == "thd"
+
+    # common assertions
+    assert 0.0 <= dropout_p <= 1.0, f"dropout_p must be between 0 and 1, got {dropout_p}"
+    assert q.device == k.device == v.device == o.device, \
+        f"All tensors must be on the same device. Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}"
+    assert q.dtype == k.dtype == v.dtype, "q, k, v must have the same dtype"
+    current_device = torch.cuda.current_device()
+    assert q.is_cuda and q.device.index == current_device, f"Device mismatch: Kernel will launch on cuda:{current_device}, but tensors are on {q.device}"
+
+    # get shapes and strides
+    if IS_VARLEN:
+        # shape
+        total_seqlen_q, nheads_q, head_size_q = q.shape
+        total_seqlen_k, nheads_k, head_size_k = k.shape
+        total_seqlen_v, nheads_v, head_size_v = v.shape
+
+        # assert shapes
+        assert cu_seqlens_q is not None, "cu_seqlens_q must be provided for varlen layout"
+        assert cu_seqlens_k is not None, "cu_seqlens_k must be provided for varlen layout"
+        assert max_seqlens_q is not None and max_seqlens_q > 0, "max_seqlens_q must be provided and positive for varlen layout"
+        assert max_seqlens_k is not None and max_seqlens_k > 0, "max_seqlens_k must be provided and positive for varlen layout"
+
+        # assert head dimensions
+        assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}"
+        assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}"
+        assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA"
+
+        # assert output shapes
+        assert o.shape == (total_seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_q)}"
+
+        # assert cu_seqlens
+        assert cu_seqlens_q.dtype == torch.int32, f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}"
+        assert cu_seqlens_k.dtype == torch.int32, f"cu_seqlens_k must be int32, got {cu_seqlens_k.dtype}"
+        assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0"
+        assert cu_seqlens_k[0] == 0, "cu_seqlens_k must start with 0"
+        assert cu_seqlens_q[-1] == total_seqlen_q, f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} != total_seqlen_q {total_seqlen_q}"
+        assert cu_seqlens_k[-1] == total_seqlen_k, f"cu_seqlens_k[-1] {cu_seqlens_k[-1]} != total_seqlen_k {total_seqlen_k}"
+
+        # set vars
+        batch = len(cu_seqlens_q) - 1
+        head_size = head_size_q
+
+        # softmax_lse shape
+        softmax_lse = torch.zeros((nheads_q, total_seqlen_q), device=q.device, dtype=torch.float32)
+
+        # strides
+        stride_qb, stride_qh, stride_qm, stride_qd = 0, q.stride(1), q.stride(0), q.stride(2)
+        stride_kb, stride_kh, stride_kn, stride_kd = 0, k.stride(1), k.stride(0), k.stride(2)
+        stride_vb, stride_vh, stride_vn, stride_vd = 0, v.stride(1), v.stride(0), v.stride(2)
+        stride_ob, stride_oh, stride_om, stride_od = 0, o.stride(1), o.stride(0), o.stride(2)
+        stride_lse_z, stride_lse_h, stride_lse_m = 0, softmax_lse.stride(0), softmax_lse.stride(1)
+    else:
+        # shapes
+        batch_q, seqlen_q, nheads_q, head_size_q = q.shape
+        batch_k, seqlen_k, nheads_k, head_size_k = k.shape
+        batch_v, seqlen_v, nheads_v, head_size_v = v.shape
+
+        # assert batch dimensions
+        assert batch_q == batch_k == batch_v, f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}"
+
+        # assert head dimensions
+        assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}"
+        assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}"
+        assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA"
+
+        # assert sequence lengths
+        assert seqlen_k == seqlen_v, f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}"
+
+        # assert output shapes
+        assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(batch_q, seqlen_q, nheads_q, head_size_q)}"
+
+        # set vars
+        batch = batch_q
+        head_size = head_size_q
+        max_seqlens_q = seqlen_q
+        max_seqlens_k = seqlen_k
+
+        # softmax_lse shape
+        softmax_lse = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32)
+
+        # strides
+        stride_qb, stride_qh, stride_qm, stride_qd = q.stride(0), q.stride(2), q.stride(1), q.stride(3)
+        stride_kb, stride_kh, stride_kn, stride_kd = k.stride(0), k.stride(2), k.stride(1), k.stride(3)
+        stride_vb, stride_vh, stride_vn, stride_vd = v.stride(0), v.stride(2), v.stride(1), v.stride(3)
+        stride_ob, stride_oh, stride_om, stride_od = o.stride(0), o.stride(2), o.stride(1), o.stride(3)
+        stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride()
+
+    # fp8 setup and assertions
     IS_FP8 = is_fp8(q)
     if IS_FP8:
-        FP8_MAX: tl.constexpr = torch.finfo(q.dtype).max
-        assert is_fp8(q) and is_fp8(k) and is_fp8(v), f"Non fp8 type found: q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}. All tensors must be fp8."
+        # we already asserted that q, k, v all have the same dtype, so no need to check each one
+        FP8_MAX = torch.finfo(q.dtype).max
+
+        # Check descale tensors
+        assert descale_q is not None, "descale_q must be provided when using fp8"
+        assert descale_k is not None, "descale_k must be provided when using fp8"
+        assert descale_v is not None, "descale_v must be provided when using fp8"
+
         if is_fp8(o):
             FP8_OUTPUT = True
-            assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor for the output."
+            assert descale_o is not None, "descale_o is None. In fp8, you need to pass a tensor for descale_o along with the output tensor o."
         else:
             FP8_OUTPUT = False
+            # o should be fp32 or fp16/bf16
+            assert o.dtype in [torch.float16, torch.bfloat16, torch.float32], \
+                f"Output tensor o must be fp16, bf16, or fp32 when using fp8, got {o.dtype}"

-        # Get strides for the kernel
         stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None
         stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None
         stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None
@@ -921,65 +1013,23 @@ def attention_prefill_forward_triton_impl(
         FP8_OUTPUT = False
         descale_q = descale_k = descale_v = descale_o = None
         stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = None
-
-    # check flags
-    IS_VARLEN = layout == "thd"
+
+        # check output dtype matches input dtype when not using fp8
+        assert o.dtype == q.dtype, f"Output dtype {o.dtype} must match input dtype {q.dtype} when not using fp8"
+
+    # check features
     use_sliding_window = window_size_left != -1 or window_size_right != -1
     use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0))
-    is_inference = False if cache_seqlens is None else True
-    if is_inference:
-        assert layout == "bshd", f"{layout} layout is not supported with inference. Use bshd layout"
-    if DEBUG:
-        print(f"is_inference:", is_inference)
     # NOTE: a large bias tensor leads to overflow during pointer arithmetic
     if (bias is not None):
         assert (bias.numel() < 2**31)
-    # get shape and strides
-    if IS_VARLEN:  # thd layout
-        # shape
-        total_q, nheads_q, head_size = q.shape
-        _, nheads_k, _ = k.shape
-        assert cu_seqlens_q is not None
-        batch = len(cu_seqlens_q) - 1
-
-        # softmax_lse is the log of the normalization constant / sum of expoential score(unnormalzied probablities)
-        softmax_lse = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32)
-
-        # strides
-        stride_qb, stride_qh, stride_qm, stride_qd = 0, q.stride(1), q.stride(0), q.stride(2)
-        stride_kb, stride_kh, stride_kn, stride_kd = 0, k.stride(1), k.stride(0), k.stride(2)
-        stride_vb, stride_vh, stride_vn, stride_vd = 0, v.stride(1), v.stride(0), v.stride(2)
-        stride_ob, stride_oh, stride_om, stride_od = 0, o.stride(1), o.stride(0), o.stride(2)
-        stride_lse_z, stride_lse_h, stride_lse_m = 0, softmax_lse.stride(0), softmax_lse.stride(1)
-    else:  # bshd layout
-        # shape
-        batch, seqlen_q, nheads_q, head_size = q.shape
-        _, _, nheads_k, _ = k.shape
-
-        # softmax_lse is the log of the normalization constant / sum of expoential score(unnormalzied probablities)
-        softmax_lse = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32)
-
-        # strides
-        stride_qb, stride_qh, stride_qm, stride_qd = q.stride(0), q.stride(2), q.stride(1), q.stride(3)
-        stride_kb, stride_kh, stride_kn, stride_kd = k.stride(0), k.stride(2), k.stride(1), k.stride(3)
-        stride_vb, stride_vh, stride_vn, stride_vd = v.stride(0), v.stride(2), v.stride(1), v.stride(3)
-        stride_ob, stride_oh, stride_om, stride_od = o.stride(0), o.stride(2), o.stride(1), o.stride(3)
-        stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride()
-
     # Get closest power of 2 over or equal to 32.
     padded_d_model = 1 << (head_size - 1).bit_length()
     # Smallest head_dim supported is 16. If smaller, the tile in the
     # kernel is padded - there is no padding in memory for any dims.
     padded_d_model = max(padded_d_model, 16)

-    FLIP_GRID = True
-    if FLIP_GRID:
-        grid = lambda META: (batch, nheads_q, triton.cdiv(max_seqlens_q, META['BLOCK_M']))
-    else:
-        grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch)
-
     # sd_mask is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out
     # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according
     # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing
@@ -999,13 +1049,14 @@ def attention_prefill_forward_triton_impl(
         dropout_mask = None
         stride_sz, stride_sh, stride_sm, stride_sn = (0, 0, 0, 0)

-
     if bias is not None:
         stride_bz, stride_bh, stride_bm, stride_bn = (bias.stride(0), bias.stride(1), bias.stride(2), bias.stride(3))
     else:
         stride_bz, stride_bh, stride_bm, stride_bn = (0, 0, 0, 0)

+    # launch kernel
+    grid = lambda META: (batch, nheads_q, triton.cdiv(max_seqlens_q, META['BLOCK_M']))
     attn_fwd[grid](q, k, v, bias, cache_seqlens, cache_batch_idx,
                    descale_q, descale_k, descale_v, descale_o,
                    stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z,
                    sm_scale, softmax_lse, o,
@@ -1025,6 +1076,6 @@ def attention_prefill_forward_triton_impl(
                    IS_VARLEN=IS_VARLEN, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
                    USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p > 0.0,
                    USE_EXP2=use_exp2, RETURN_SCORES=return_softmax,
-                   IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT, FLIP_GRID=FLIP_GRID)
+                   IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT)

     return softmax_lse, sd_mask if return_softmax else None
\ No newline at end of file
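A closing note on the head-dim padding both files rely on: padded_d_model = 1 << (head_size - 1).bit_length() rounds head_size up to the next power of two, and max(padded_d_model, 16) enforces the minimum tile width; the padding exists only in the kernel tile (masked via PADDED_HEAD), never in memory. A small self-contained sketch of that rule, with illustrative values:

def padded_head_dim(head_size: int) -> int:
    # round up to the next power of two, with a floor of 16
    return max(1 << (head_size - 1).bit_length(), 16)

assert padded_head_dim(64) == 64   # already a power of two
assert padded_head_dim(72) == 128  # rounded up
assert padded_head_dim(8) == 16    # floor applies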