diff --git a/.github/workflows/amd_nightly.yml b/.github/workflows/amd_nightly.yml index fdc0453413c..3131496ac49 100644 --- a/.github/workflows/amd_nightly.yml +++ b/.github/workflows/amd_nightly.yml @@ -21,7 +21,7 @@ jobs: timeout-minutes: 720 # self hosted runners can run jobs for longer than the default of 360 minutes container: image: rocm/pytorch:latest - options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root + options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --shm-size 16G --group-add video --user root steps: - name: Checkout uses: actions/checkout@v4 @@ -38,7 +38,7 @@ jobs: - name: Install Triton run: | - pip install triton==3.2.0 + pip install triton==3.3.0 - name: Show Triton version run: | @@ -50,7 +50,7 @@ jobs: - name: Install dependencies for bench and misc run: | - pip install numpy==1.24 matplotlib pandas tabulate + pip install matplotlib pandas tabulate - name: AMD Internal Tests run: | @@ -58,7 +58,7 @@ jobs: - name: Flash Attention Tests run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest tests/test_flash_attn_triton_amd.py + FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest -n 8 tests/test_flash_attn_triton_amd.py - name: AMD Bench run: | @@ -90,7 +90,7 @@ jobs: - name: Install Triton run: | - pip install triton==3.2.0 + pip install triton==3.3.0 - name: Show Triton version run: | diff --git a/.github/workflows/amd_tests.yml b/.github/workflows/amd_tests.yml index 2122e680458..2f49567f960 100644 --- a/.github/workflows/amd_tests.yml +++ b/.github/workflows/amd_tests.yml @@ -19,7 +19,7 @@ jobs: timeout-minutes: 720 # self hosted runners can run jobs for longer than the default of 360 minutes container: image: rocm/pytorch:latest - options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root + options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --shm-size 16G --group-add video --user root steps: - name: Checkout uses: actions/checkout@v4 @@ -36,7 +36,7 @@ jobs: - name: Install Triton run: | - pip install triton==3.2.0 + pip install triton==3.3.0 - name: Show Triton version run: | @@ -48,7 +48,7 @@ jobs: - name: Install dependencies for bench and misc run: | - pip install numpy==1.24 matplotlib pandas tabulate + pip install matplotlib pandas tabulate - name: AMD Internal Tests run: | @@ -56,7 +56,7 @@ jobs: - name: Flash Attention Tests run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest tests/test_flash_attn_triton_amd.py + FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest -n 8 tests/test_flash_attn_triton_amd.py - name: AMD Bench run: | diff --git a/README.md b/README.md index 6fed22c9a8a..8db04b36ad0 100644 --- a/README.md +++ b/README.md @@ -154,7 +154,7 @@ To get started with the triton backend for AMD, follow the steps below. First install the recommended Triton version ``` -pip install triton==3.2.0 +pip install triton==3.3.0 ``` Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. @@ -182,7 +182,7 @@ FROM rocm/pytorch:latest WORKDIR /workspace # install triton -RUN pip install triton==3.2.0 +RUN pip install triton==3.3.0 # install flash attention ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" diff --git a/flash_attn/flash_attn_triton_amd/Dockerfile b/flash_attn/flash_attn_triton_amd/Dockerfile index 29a2c0c43ec..8df939a0886 100644 --- a/flash_attn/flash_attn_triton_amd/Dockerfile +++ b/flash_attn/flash_attn_triton_amd/Dockerfile @@ -3,7 +3,7 @@ FROM rocm/pytorch:latest WORKDIR /workspace # install triton -RUN pip install triton==3.2.0 +RUN pip install triton==3.3.0 # install flash attention ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" diff --git a/flash_attn/flash_attn_triton_amd/README.md b/flash_attn/flash_attn_triton_amd/README.md index 2d8fd8e70f3..f3a5db67fc5 100644 --- a/flash_attn/flash_attn_triton_amd/README.md +++ b/flash_attn/flash_attn_triton_amd/README.md @@ -28,7 +28,7 @@ To get started with the triton backend for AMD, follow the steps below. First install the recommended Triton version ``` -pip install triton==3.2.0 +pip install triton==3.3.0 ``` Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. @@ -56,7 +56,7 @@ FROM rocm/pytorch:latest WORKDIR /workspace # install triton -RUN pip install triton==3.2.0 +RUN pip install triton==3.3.0 # install flash attention ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" diff --git a/flash_attn/flash_attn_triton_amd/bench.py b/flash_attn/flash_attn_triton_amd/bench.py index 05e64c349be..f2b2e7d11d6 100755 --- a/flash_attn/flash_attn_triton_amd/bench.py +++ b/flash_attn/flash_attn_triton_amd/bench.py @@ -80,7 +80,7 @@ class EnvVariableConfig: backend: Optional[Literal["triton", "ck"]] = None ENV_VARIABLE_CONFIGS : List[EnvVariableConfig] = [ - EnvVariableConfig(key="BWD_MODE", values=["split", "fused", "jingning"], backend="triton"), + # EnvVariableConfig(key="BWD_MODE", values=["split", "fused", "jingning"], backend="triton"), ] class FunctionConfig: @@ -871,8 +871,8 @@ def load_flash_attn_module(backend: Literal["triton", "ck"], env_configs: Dict = # set environment variable for the desired backend if backend == "triton": os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE" - os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "0" os.environ["FLASH_ATTENTION_TRITON_AMD_DEBUG"] = "0" + os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "1" elif backend == "ck": os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "FALSE" else: @@ -1016,15 +1016,30 @@ def get_input_config_set(config_type): # batch, hq, hk, sq, sk, d_head, causal, dropout input_configs = [ # LLaMA 3 8B - (1, 32, 8, 8192, 8192, 128, True, 0.0), + (4, 32, 8, 8192, 8192, 128, True, 0.0), # LLaMA 3 70B - (1, 64, 8, 8192, 8192, 128, True, 0.0), + (4, 64, 8, 8192, 8192, 128, True, 0.0), ] else: raise ValueError(f"Unknown input config: {config_type}") return input_configs +def filter_backends(requested_backends, supported_backends, fn_name): + if requested_backends: + selected = [] + for be in requested_backends: + if be in supported_backends: + selected.append(be) + else: + warning( + f"backend '{be}' requested but not supported by " + f"function '{fn_name}'. skipping this back-end." + ) + return selected + else: + return supported_backends + def process_args(): """ @@ -1052,6 +1067,14 @@ def process_args(): default=None, help=f"Benchmarking mode(s) to run. If omitted, runs all supported modes for each function.", ) + parser.add_argument( + "--backend", + type=str, + nargs='*', + choices=["triton", "ck"], + default=None, + help="Back-end(s) to run (triton, ck). Omit to run every back-end that is both available and supported by the function.", + ) # config parser.add_argument("-b", type=int, default=None, help="Batch size") parser.add_argument("-hq", type=int, default=None, help="Q Number of heads") @@ -1067,7 +1090,8 @@ def process_args(): # parse function args benchmark_fns = args.benchmark_fn - requested_modes = args.mode + requested_modes = args.mode + requested_backends = args.backend # fenerate function configurations and input configurations separately all_function_configs = [] @@ -1101,9 +1125,17 @@ def process_args(): if not modes_to_run: warning(f"No valid modes to run for function '{fn_name}' based on request and function support. Skipping this function.") continue + + # filter by backend + backends_to_run = filter_backends(requested_backends, + supported_backends, + fn_name) + if not backends_to_run: + warning(f"no valid back-ends left for '{fn_name}'. skipping.") + continue # create a function config for each backend and dtype combination - for backend in supported_backends: + for backend in backends_to_run: for dtype in supported_dtypes: for mode in modes_to_run: for env_config in supported_env_configs[backend]: @@ -1124,6 +1156,52 @@ def check_environment_variables(): if key in os.environ: raise ValueError(f"Running with {key} environment variable is not recommended for the benching script. Use --help to see how to use the benching script.") +def compute_flops(batch, hq, hk, sq, sk, d_head, causal): + # 2 FLOPs per multiply‑add + if causal: + valid_pairs = ((sk * (sk + 1)) // 2 if sq > sk else + sq * sk - (sq * (sq - 1)) // 2) + else: + valid_pairs = sq * sk + return 2 * batch * hq * valid_pairs * d_head + +# see ref, https://github.com/ROCm/aiter/blob/jukorhon/mha-bwd/op_benchmarks/triton/bench_mha.py +def _flops_single_row(row: pd.Series, mode: str) -> float: + b, hq, d_head = int(row["BATCH"]), int(row["HQ"]), int(row["D_HEAD"]) + sq, sk = int(row["N_CTX_Q"]), int(row["N_CTX_K"]) + causal = bool(row["CAUSAL"]) + + # -------- number of (query, key) products per head ---------------- + if not causal: + valid_pairs = sq * sk + else: # triangular mask + if sq > sk: + valid_pairs = sk * (sk + 1) // 2 + (sq - sk) * sk + else: # sq <= sk + valid_pairs = sq * (sq + 1) // 2 + + # one matmul FLOPs (mul + add) = 2 · m · n · k + flops_per_matmul = 2.0 * b * hq * valid_pairs * d_head + total_flops = 2.0 * flops_per_matmul # 2 matmuls in forward + + if mode == "fwd": + pass + elif mode == "bwd": + total_flops *= 2.5 # 2·bwd + 0.5·recompute + elif mode == "full": + total_flops *= 3.5 # fwd + bwd + else: + raise ValueError(f"unknown mode {mode}") + + return total_flops + +def add_tflops_columns(df: pd.DataFrame, func_cfg: FunctionConfig) -> pd.DataFrame: + ms_col = func_cfg.column_name() + tf_col = ms_col.replace("_ms", "_tflops") + flops = df.apply(_flops_single_row, axis=1, mode=func_cfg.mode) + df[tf_col] = flops / df[ms_col] * 1e-9 + return df + def main(): """ Main function to run benchmarks. @@ -1137,27 +1215,30 @@ def main(): # process args to get function configs and input configs function_configs, all_input_configs = process_args() - # Check if we have multiple function configurations - has_multiple_func_configs = len(function_configs) > 1 - combined_df = None - # run benchmarks for each function configuration + combined_ms_df = None + combined_tf_df = None + input_cols = ["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K", "D_HEAD", "CAUSAL", "DROPOUT"] for func_config in function_configs: # run benchmark with the input configs for this function config input_configs = all_input_configs[func_config] df = run_benchmark(func_config, input_configs) + df = add_tflops_columns(df, func_config) - # Define the columns that represent input configurations - input_config_cols = ["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K", "D_HEAD", "CAUSAL", "DROPOUT"] - - # merge into one final dataframe - if combined_df is None: - combined_df = df + # add to combined table + ms_cols = [c for c in df.columns if c.endswith('_ms')] + tf_cols = [c for c in df.columns if c.endswith('_tflops')] + + ms_df = df[input_cols + ms_cols] + tf_df = df[input_cols + tf_cols] + + if combined_ms_df is None: + combined_ms_df = ms_df + combined_tf_df = tf_df else: - # Ensure we're joining on input configuration columns - combined_df = combined_df.merge(df, on=input_config_cols, how="outer") + combined_ms_df = combined_ms_df.merge(ms_df, on=input_cols, how="outer") + combined_tf_df = combined_tf_df.merge(tf_df, on=input_cols, how="outer") - # print new line to seperate the combined data information from the benchmark specific information print() @@ -1166,6 +1247,7 @@ def main(): print(f"Total time for all benchmarks: {total_elapsed_time:.2f} seconds") # save combined data and make comparisons if we have multiple function configs + has_multiple_func_configs = False # len(function_configs) > 1 if has_multiple_func_configs: if len(function_configs) == 2: func1 = function_configs[0] @@ -1194,30 +1276,25 @@ def main(): ratio_col = f"ck_to_triton_ratio" # Calculate ratio: ck_time / triton_time (values > 1 mean triton is faster) - combined_df[ratio_col] = combined_df[ck_col] / combined_df[triton_col] + combined_ms_df[ratio_col] = combined_ms_df[ck_col] / combined_ms_df[triton_col] # print explanation print(f"Comparison Results (triton vs ck):") print(f"Ratio values: values > 1 mean triton is faster (by that factor), values < 1 mean ck is faster") - elif False: - # For other comparisons, use the standard approach - ratio_col = f"{func1}_to_{func2}_ratio" - - # Calculate the ratio - combined_df[ratio_col] = combined_df[col2] / combined_df[col1] - - # print explanation - print(f"Comparison Results ({func1} vs {func2}):") - print(f"Ratio values: values > 1 mean {func1} is faster than {func2} (by that factor), values < 1 mean slower") - - print(f"Combined data:") - print(combined_df) - - # save csv & markdown - combined_filename = f"benchmark_combined" - combined_df.to_csv(f"{combined_filename}.csv", index=False) - with open(f"{combined_filename}.md", 'w') as f: - f.write(combined_df.to_markdown(index=False, floatfmt=".2f")) + + if combined_ms_df is not None: + print("\nCombined wall‑time (ms) table:") + print(combined_ms_df) + combined_ms_df.to_csv("benchmark_ms.csv", index=False) + with open("benchmark_ms.md", 'w') as f: + f.write(combined_ms_df.to_markdown(index=False, floatfmt=".2f")) + + if combined_tf_df is not None: + print("\nCombined throughput (TFLOPs) table:") + print(combined_tf_df) + combined_tf_df.to_csv("benchmark_tflops.csv", index=False) + with open("benchmark_tflops.md", 'w') as f: + f.write(combined_tf_df.to_markdown(index=False, floatfmt=".2f")) if __name__ == "__main__": main() \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py index 3c018be4fa0..af3f8790026 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py @@ -354,9 +354,9 @@ def _attn_fwd(q_ptr: torch.Tensor, VARLEN: tl.constexpr, ): #calculate offsets - start_m = tl.program_id(0) #seqlen_q - off_q_head = tl.program_id(1) #num_q_heads - off_z = tl.program_id(2) #batch + off_z = tl.program_id(0) #batch + off_q_head = tl.program_id(1) #num_q_heads + start_m = tl.program_id(2) #seqlen_q offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) @@ -730,14 +730,14 @@ def _flash_attn_forward( # Tuned for MI300x config = { 'BLOCK_M': 128, - 'BLOCK_N': 32, # BLOCK_N: 64 spills for _attn_fwd + 'BLOCK_N': 64, 'waves_per_eu': 2, 'num_warps': 4, 'num_ctas': 1, 'num_stages': 1, } - grid = lambda META:(triton.cdiv(seqlen_q, META['BLOCK_M']), num_q_heads, batch) + grid = lambda META:(batch, num_q_heads, triton.cdiv(seqlen_q, META['BLOCK_M'])) _attn_fwd[grid](q, k, v, @@ -1105,6 +1105,7 @@ def _bwd_dkdvdq_inner( dk, dv, Q, k, v, DO, DQ, M, D, sm_scale, stride_q_m, stride_q_k, + stride_dq_m, stride_dq_k, stride_do_m, stride_do_k, stride_dropout_m, stride_dropout_n, stride_deltam, @@ -1132,7 +1133,7 @@ def _bwd_dkdvdq_inner( mask_n = offs_n < seqlen_k qT_ptrs_start = Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k #[BLOCK_D_MODEL_POW2, BLOCK_M] - dq_ptrs_start = DQ + offs_m[:, None] * stride_q_m + offs_k[None,:] * stride_q_k #[BLOCK_M, BLOCK_D_MODEL_POW2] + dq_ptrs_start = DQ + offs_m[:, None] * stride_dq_m + offs_k[None,:] * stride_dq_k #[BLOCK_M, BLOCK_D_MODEL_POW2] do_ptrs_start = DO + offs_m[:, None] * stride_do_m + offs_k[None,: ] * stride_do_k curr_m = start_m @@ -1170,7 +1171,7 @@ def _bwd_dkdvdq_inner( curr_m = start_m + blk_idx * step_m qT_ptrs = qT_ptrs_start + blk_idx * step_m * stride_q_m - dq_ptrs = dq_ptrs_start + blk_idx * step_m * stride_q_m + dq_ptrs = dq_ptrs_start + blk_idx * step_m * stride_dq_m do_ptrs = do_ptrs_start + blk_idx * step_m * stride_do_m offs_m = curr_m + tl.arange(0, BLOCK_M) @@ -1279,6 +1280,7 @@ def _bwd_kernel_dkdvdq_causal( stride_k_b, stride_k_h, stride_k_n, stride_k_k, stride_v_b, stride_v_h, stride_v_n, stride_v_k, stride_dk_b, stride_dk_h, stride_dk_n, stride_dk_k, + stride_dq_b, stride_dq_h, stride_dq_m, stride_dq_k, stride_delta_b, stride_delta_h, stride_delta_m, stride_do_b, stride_do_h, stride_do_m, stride_do_k, stride_dropout_b, stride_dropout_h, stride_dropout_m, stride_dropout_n, @@ -1387,9 +1389,10 @@ def _bwd_kernel_dkdvdq_causal( # offset input and output tensor by batch and Q/K heads adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m + adj_dq = batch_idx * stride_dq_b + head_q_idx * stride_dq_h + q_start * stride_dq_m q_ptr_adj = q_ptr + adj_q - dq_ptr_adj = dq_ptr + adj_q + dq_ptr_adj = dq_ptr + adj_dq adj_do = batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m do_ptr_adj = do_ptr + adj_do @@ -1425,12 +1428,13 @@ def _bwd_kernel_dkdvdq_causal( else: descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - # if start_m is negative, the current N-tile has no block on the + # if unaligned start_m is negative, the current N-tile has no block on the # diagonal of causal mask, so everything have no causal mask dk, dv = _bwd_dkdvdq_inner( dk, dv, # output tensors q_ptr_adj, k, v, do_ptr_adj, dq_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors stride_q_m, stride_q_k, # strides for q + stride_dq_m, stride_dq_k, # strides for q stride_do_m, stride_do_k, # strides for o stride_dropout_m, stride_dropout_n, # strides for dropout stride_delta_m, @@ -1446,14 +1450,19 @@ def _bwd_kernel_dkdvdq_causal( FP8_MAX=FP8_MAX, workgroup_id=seq_k_blk_idx, ) + + start_m += num_steps * MASK_BLOCK_M num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) end_m = start_m + num_steps * BLOCK_M + + dk, dv = _bwd_dkdvdq_inner( dk, dv, # output tensors q_ptr_adj, k, v, do_ptr_adj, dq_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors stride_q_m, stride_q_k, # strides for q + stride_dq_m, stride_dq_k, # strides for dq stride_do_m, stride_do_k, # strides for o stride_dropout_m, stride_dropout_n, # strides for dropout stride_delta_m, @@ -1865,6 +1874,7 @@ def _bwd_kernel_dkdvdq_noncausal( stride_kb, stride_kh, stride_kn, stride_kk, stride_vb, stride_vh, stride_vn, stride_vk, stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, stride_deltab, stride_deltah, stride_deltam, stride_dob, stride_doh, stride_dom, stride_dok, stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, @@ -1939,9 +1949,10 @@ def _bwd_kernel_dkdvdq_noncausal( for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): adj_q = (bid * stride_qb + hqid * stride_qh + q_start * stride_qm) - + adj_dq = (bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm) + Q_ptr = Q + adj_q - DQ_ptr = DQ + adj_q + DQ_ptr = DQ + adj_dq adj_do = (bid * stride_dob + hqid * stride_doh + q_start * stride_dom) DO_ptr = DO + adj_do @@ -1973,6 +1984,7 @@ def _bwd_kernel_dkdvdq_noncausal( dk, dv, Q_ptr, k, v, DO_ptr, DQ_ptr, M_ptr, Delta_ptr, sm_scale, stride_qm, stride_qk, + stride_dqm, stride_dqk, stride_dom, stride_dok, stride_dropoutm, stride_dropoutn, stride_deltam, @@ -2372,7 +2384,7 @@ def _flash_attn_backward( if fused: # fuses dk, dv, dq computations into one kernel by computing the dq using atomic adds between workgroups - BLOCK_N = 128 + BLOCK_N = 128 if BLOCK_D_MODEL_POW2 < 160 else 64 # larger head sizes lead to oom config = { "BLOCK_M": 32, "BLOCK_N": BLOCK_N, @@ -2393,6 +2405,7 @@ def _flash_attn_backward( *k_strides, *v_strides, *dk_strides, + *dq_strides, *delta_strides, *do_strides, *dropout_strides, @@ -2422,6 +2435,7 @@ def _flash_attn_backward( *k_strides, *v_strides, *dk_strides, + *dq_strides, *delta_strides, *do_strides, *dropout_strides, @@ -2776,6 +2790,7 @@ def forward( return_lse, return_softmax, is_grad_enabled, + fused_backward, ): is_grad = is_grad_enabled and any( x.requires_grad for x in [q,k,v] @@ -2812,7 +2827,7 @@ def forward( cu_seqlens_k=None, descale_q=descale_q, descale_k=descale_k, - descale_v=descale_v + descale_v=descale_v, ) if is_grad: @@ -2824,6 +2839,7 @@ def forward( ctx.causal = causal ctx.window_size = window_size ctx.alibi_slopes = alibi_slopes + ctx.fused_backward = fused_backward out = out_padded[..., :head_size_og] result = [out] @@ -2869,11 +2885,12 @@ def backward(ctx, do, *args): descale_k=descale_k, descale_v=descale_v, descale_do=descale_do, + fused=ctx.fused_backward, ) #dq = dq[..., : q_fp8.shape[-1]] # We could have padded the head dimension #dk = dk[..., : k_fp8.shape[-1]] #dv = dv[..., : v_fp8.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None def flash_attn_fp8_func( q, @@ -2886,7 +2903,8 @@ def flash_attn_fp8_func( alibi_slopes=None, deterministic=False, return_lse=False, - return_attn_probs=False + return_attn_probs=False, + fused_backward=False, ): return FlashAttnFP8Func.apply( q, @@ -2900,7 +2918,8 @@ def flash_attn_fp8_func( deterministic, return_lse, return_attn_probs, - torch.is_grad_enabled() + torch.is_grad_enabled(), + fused_backward, ) class FlashAttnVarlenFunc(torch.autograd.Function): @@ -3127,6 +3146,7 @@ def forward( return_softmax, block_table, is_grad_enabled, + fused_backward, ): is_grad = is_grad_enabled and any( x.requires_grad for x in [q, k, v] @@ -3163,7 +3183,8 @@ def forward( cu_seqlens_k=cu_seqlens_k, descale_q=descale_q, descale_k=descale_k, - descale_v=descale_v + descale_v=descale_v, + fused_backward=fused_backward, ) if is_grad: ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, descale_q, descale_k, descale_v) @@ -3176,6 +3197,7 @@ def forward( ctx.causal = causal ctx.window_size = window_size ctx.alibi_slopes = alibi_slopes + ctx.fused_backward = fused_backward out = out_padded[..., :head_size_og] result = [out] if return_lse: @@ -3187,15 +3209,15 @@ def forward( @staticmethod def backward(ctx, do, *args): - q_fp8, k_fp8, v_fp8, out, softmax_lse, cu_seqlens_q, cu_seqlens_q, descale_q, descale_k, descale_v = ctx.saved_tensors - dq, dk, dv = torch.zeros_like(q, dtype=torch.float32), torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32) + q_fp8, k_fp8, v_fp8, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, descale_q, descale_k, descale_v = ctx.saved_tensors + dq, dk, dv = torch.zeros_like(q_fp8, dtype=torch.float32), torch.zeros_like(k_fp8, dtype=torch.float32), torch.zeros_like(v_fp8, dtype=torch.float32) head_size_v_og = do.size(3) do_padded = do if head_size_v_og % 8 != 0: do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) fp8_dtype = torch.float8_e4m3fnuz - do_padded_fp8, descale_do = cast_varlen_to_fp8(dout_padded, fp8_dtype, "thd", cu_seqlens_q) + do_padded_fp8, descale_do = cast_varlen_to_fp8(do_padded, fp8_dtype, "thd", cu_seqlens_q) _flash_attn_backward( do_padded_fp8, @@ -3212,8 +3234,8 @@ def backward(ctx, do, *args): ctx.causal, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_k=ctx.max_seqlen_k, dropout_p=ctx.dropout_p, philox_seed=ctx.philox_seed, philox_offset=ctx.philox_offset, @@ -3222,10 +3244,10 @@ def backward(ctx, do, *args): descale_v=descale_v, descale_do=descale_do ) - dq = dq[..., : q.shape[-1]] # We could have padded the head dimension - dk = dk[..., : k.shape[-1]] - dv = dv[..., : v.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None + dq = dq[..., : q_fp8.shape[-1]] # We could have padded the head dimension + dk = dk[..., : k_fp8.shape[-1]] + dv = dv[..., : v_fp8.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None def flash_attn_varlen_fp8_func( q, @@ -3243,7 +3265,8 @@ def flash_attn_varlen_fp8_func( deterministic=False, return_lse=False, return_attn_probs=False, - block_table=None + block_table=None, + fused_backward=False, ): return FlashAttnVarlenFP8Func.apply( q, @@ -3262,5 +3285,6 @@ def flash_attn_varlen_fp8_func( return_lse, return_attn_probs, block_table, - torch.is_grad_enabled() - ) \ No newline at end of file + torch.is_grad_enabled(), + fused_backward, + ) diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py index 3f650d288db..9f8a1ab46a2 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py @@ -2,8 +2,8 @@ import triton # type: ignore import triton.language as tl # type: ignore from typing import Literal, Optional -from .utils import AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, get_shapes_from_layout, compute_fp8_scaling_factors, \ - get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_cdna, is_rdna +from .utils import DEBUG, AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, get_shapes_from_layout, compute_fp8_scaling_factors, \ + get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_cdna, is_fp8, is_rdna # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) @@ -108,7 +108,7 @@ def get_autotune_configs(): def _bwd_preprocess( O, DO, # noqa: E741 Delta, - stride_ob, stride_oh, stride_om, stride_ok, + stride_ob, stride_oh, stride_om, stride_od, stride_deltab, stride_deltah, stride_deltam, stride_descale_do_z, cu_seqlens_q, max_seqlen_q, @@ -135,7 +135,7 @@ def _bwd_preprocess( # Compute offsets offs_m = pid_m * PRE_BLOCK + tl.arange(0, PRE_BLOCK) - offs_k = tl.arange(0, HEAD_DIM) + offs_d = tl.arange(0, HEAD_DIM) # Offset O/DO by batch, head and q_start O += bid * stride_ob + hid * stride_oh + q_start * stride_om # noqa: E741 DO += bid * stride_ob + hid * stride_oh + q_start * stride_om @@ -144,9 +144,9 @@ def _bwd_preprocess( mask_md = mask_m[:, None] PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) if PADDED_HEAD: - mask_md &= offs_k[None, :] < ACTUAL_HEAD_DIM + mask_md &= offs_d[None, :] < ACTUAL_HEAD_DIM # compute pointers - offs_do = offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok + offs_do = offs_m[:, None] * stride_om + offs_d[None, :] * stride_od out_ptrs = O + offs_do do_ptrs = DO + offs_do # load @@ -171,19 +171,21 @@ def _bwd_dkdv_inner( Q, k, v, DO, M, D, sm_scale, # input tensor stride_qm, stride_qk, stride_dom, stride_dok, - stride_dropoutm, stride_dropoutn, # + stride_dropoutm, stride_dropoutn, stride_deltam, BLOCK_M: tl.constexpr, # 16 BLOCK_N: tl.constexpr, # 128 HEAD_DIM: tl.constexpr, # ACTUAL_HEAD_DIM: tl.constexpr, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k # Filled in by the wrapper. start_n, start_m, num_steps, # iteration numbers descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal ENABLE_DROPOUT: tl.constexpr, # activate dropout + USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, # activate exp2 IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, @@ -246,16 +248,23 @@ def _bwd_dkdv_inner( qkT = (tl.dot(k, qT) * descale_q * descale_k) else: qkT = tl.dot(k, qT) + qkT_scaled = qkT * sm_scale + + if USE_ALIBI: + relative_pos_block = offs_n[:, None] + seqlen_q - seqlen_k - offs_m[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + qkT_scaled += alibi_block + if DEBUG_TRITON_DETAIL: if start_n == 256: print(f"qT: {qT.shape}\n", qT) print(f"k: {k.shape}\n", k) - print(f"qkT scaled: {qkT.shape}\n", qkT * sm_scale) + print(f"qkT scaled: {qkT.shape}\n", qkT_scaled) # TODO: remove the scaling of m later when we removed re-scaling in fwd if USE_EXP2: - pT = tl.math.exp2(qkT * sm_scale * RCP_LN2 - m[None, :] * RCP_LN2) + pT = tl.math.exp2(qkT_scaled * RCP_LN2 - m[None, :] * RCP_LN2) else: - pT = tl.math.exp(qkT * sm_scale - m[None, :]) + pT = tl.math.exp(qkT_scaled - m[None, :]) # Autoregressive masking. if MASK: @@ -323,12 +332,14 @@ def _bwd_dq_inner( BLOCK_N2: tl.constexpr, # HEAD_DIM: tl.constexpr, ACTUAL_HEAD_DIM: tl.constexpr, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, # Filled in by the wrapper. start_m, start_n, end_n, num_steps, # descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user MASK: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, + USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, @@ -392,11 +403,18 @@ def _bwd_dq_inner( qk = (tl.dot(q, kT) * descale_q * descale_k) else: qk = tl.dot(q, kT) - if DEBUG_TRITON_DETAIL: print(f"qk scaled: {qk.shape}\n", qk * sm_scale) # noqa: E701 + qk_scaled = qk * sm_scale + + if USE_ALIBI: + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + qk_scaled += alibi_block + + if DEBUG_TRITON_DETAIL: print(f"qk scaled: {qk.shape}\n", qk_scaled) # noqa: E701 if USE_EXP2: - p = tl.math.exp2(qk * sm_scale * RCP_LN2 - m * RCP_LN2) + p = tl.math.exp2(qk_scaled * RCP_LN2 - m * RCP_LN2) else: - p = tl.math.exp(qk * sm_scale - m) + p = tl.math.exp(qk_scaled - m) # Autoregressive masking. if MASK: @@ -434,18 +452,23 @@ def _bwd_dq_inner( def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) Q, K, V, sm_scale, DO, DQ, DK, DV, M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_qb, stride_qh, stride_qm, stride_qd, + stride_kb, stride_kh, stride_kn, stride_kd, + stride_vb, stride_vh, stride_vn, stride_vd, + stride_dqb, stride_dqh, stride_dqm, stride_dqd, + stride_dkb, stride_dkh, stride_dkn, stride_dkd, + stride_dvb, stride_dvh, stride_dvn, stride_dvd, stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, + stride_dob, stride_doh, stride_dom, stride_dod, stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, HQ, HK, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset_base, + Dropout_mask, dropout_p, philox_seed, philox_offset_base, + Alibi_slopes, + Descale_q, Descale_k, Descale_v, Descale_do, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, BLOCK_M2: tl.constexpr, @@ -455,7 +478,11 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead ACTUAL_HEAD_DIM: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, DEBUG_TRITON: tl.constexpr, DEBUG_TRITON_DETAIL: tl.constexpr, ): @@ -481,7 +508,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead delta_qk = seqlen_q - seqlen_k if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") # noqa: E701 PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - offs_k = tl.arange(0, HEAD_DIM) + offs_d = tl.arange(0, HEAD_DIM) GROUP_SIZE: tl.constexpr = HQ // HK # align the delta_qk @@ -511,15 +538,15 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead # Mask for loading K and V mask_kv = offs_n[:, None] < seqlen_k if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_kv &= mask_k[None, :] - offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk + mask_d = offs_d < ACTUAL_HEAD_DIM + mask_kv &= mask_d[None, :] # K/V tensors not changed for the group - adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0) - v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0) + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) # If MQA / GQA, set the K and V head offsets appropriately. # hqid = hkid for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): @@ -546,6 +573,12 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead M_ptr = M + adj_delta Delta_ptr = Delta + adj_delta + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + # batch_philox_offset is the ACTUALLY dropout offset # dropout_offset is for debug purpose and will be removed later batch_philox_offset = 0 @@ -553,9 +586,17 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead if ENABLE_DROPOUT: batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ hqid * stride_dropouth - dropout_offset = dropout_mask + bid * stride_dropoutb + \ + dropout_offset = Dropout_mask + bid * stride_dropoutb + \ hqid * stride_dropouth + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR # bound the masked operation to q len so it does not have to wast cycles len_m = min(len_m, seqlen_q) @@ -570,21 +611,23 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead dk, dv = _bwd_dkdv_inner( dk, dv, # output tensors Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qk, # strides for q - stride_dom, stride_dok, # strides for o + stride_qm, stride_qd, # strides for q + stride_dom, stride_dod, # strides for o stride_dropoutm, stride_dropoutn, # strides for dropout stride_deltam, MASK_BLOCK_M1, BLOCK_N1, # block dim HEAD_DIM, ACTUAL_HEAD_DIM, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k start_n, start_m, num_steps, # iteration numbers - None, None, None, None, + descale_q, descale_k, descale_v, descale_do, MASK=True, # causal masking ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) @@ -598,31 +641,36 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead dk, dv = _bwd_dkdv_inner( dk, dv, # output tensors Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qk, # strides for q - stride_dom, stride_dok, # strides for o + stride_qm, stride_qd, # strides for q + stride_dom, stride_dod, # strides for o stride_dropoutm, stride_dropoutn, # strides for dropout stride_deltam, BLOCK_M1, BLOCK_N1, # block dim HEAD_DIM, ACTUAL_HEAD_DIM, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k start_n, start_m, num_steps, # iteration numbers - None, None, None, None, + descale_q, descale_k, descale_v, descale_do, MASK=False, # causal masking ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) # end of GQA/MQA of dkdv - # Write back dV and dK. - adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn - offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk - tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + # Write back dV + adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn + offs_dv = offs_n[:, None] * stride_dvn + offs_d[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_kv) + # write back dk + adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn + offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd dk *= sm_scale - tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) # This part does dq start_m = pid * BLOCK_M2 @@ -637,13 +685,14 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead # Mask for loading K and V mask_q = offs_m[:, None] < seqlen_q if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_q &= mask_k[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk - offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn - K += adj_kv - V += adj_kv + mask_d = offs_d < ACTUAL_HEAD_DIM + mask_q &= mask_d[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + # NOTE: don't assume that the strides for k and v are the same! + K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn + V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn + # If MQA / GQA, set the K and V head offsets appropriately. for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front @@ -659,6 +708,12 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam Delta_ptr = Delta + adj_delta + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + # batch_philox_offset is the ACTUALLY dropout offset # dropout_offset is for debug purpose and will be removed later batch_philox_offset = 0 @@ -668,8 +723,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead bid * stride_dropoutb + \ hqid * stride_dropouth dropout_offset = \ - dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth - + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) m = tl.load(M + adj_delta + offs_m * stride_deltam, @@ -680,24 +734,35 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead # start can only be 0 at minimum start_n = max(end_n - BLOCK_M2, 0) num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N2) + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) dq = _bwd_dq_inner( dq, - q, K, V, do, m, Delta_ptr, sm_scale, # - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropoutm, stride_dropoutn, # + q, K, V, do, m, Delta_ptr, sm_scale, + stride_qm, stride_qd, stride_kn, stride_kd, stride_vn, stride_vd, + stride_dropoutm, stride_dropoutn, stride_deltam, - seqlen_q, seqlen_k, # - BLOCK_M2, MASK_BLOCK_N2, # - HEAD_DIM, ACTUAL_HEAD_DIM, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - start_m, start_n, end_n, num_steps, # - None, None, None, None, + seqlen_q, seqlen_k, + BLOCK_M2, MASK_BLOCK_N2, + HEAD_DIM, ACTUAL_HEAD_DIM, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, MASK=True, # ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) @@ -706,28 +771,30 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead start_n = max(end_n - num_steps * BLOCK_N2, 0) if DEBUG_TRITON: print(f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 dq = _bwd_dq_inner( - dq, # - q, K, V, do, m, Delta_ptr, sm_scale, # - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, # - stride_dropoutm, stride_dropoutn, # + dq, + q, K, V, do, m, Delta_ptr, sm_scale, + stride_qm, stride_qd, stride_kn, stride_kd, stride_vn, stride_vd, + stride_dropoutm, stride_dropoutn, stride_deltam, - seqlen_q, seqlen_k, # - BLOCK_M2, BLOCK_N2, # - HEAD_DIM, ACTUAL_HEAD_DIM, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - start_m, start_n, end_n, num_steps, # - None, None, None, None, - MASK=False, # + seqlen_q, seqlen_k, + BLOCK_M2, BLOCK_N2, + HEAD_DIM, ACTUAL_HEAD_DIM, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + MASK=False, ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) # Write back dQ. adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd dq *= sm_scale tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) # end of GQA/MQA of dq @@ -741,18 +808,23 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead def bwd_kernel_noncausal( Q, K, V, sm_scale, DO, DQ, DK, DV, M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_qb, stride_qh, stride_qm, stride_qd, + stride_kb, stride_kh, stride_kn, stride_kd, + stride_vb, stride_vh, stride_vn, stride_vd, + stride_dqb, stride_dqh, stride_dqm, stride_dqd, + stride_dkb, stride_dkh, stride_dkn, stride_dkd, + stride_dvb, stride_dvh, stride_dvn, stride_dvd, stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, + stride_dob, stride_doh, stride_dom, stride_dod, stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, HQ, HK, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset_base, + Dropout_mask, dropout_p, philox_seed, philox_offset_base, + Alibi_slopes, + Descale_q, Descale_k, Descale_v, Descale_do, BLOCK_M1: tl.constexpr, # 32 BLOCK_N1: tl.constexpr, # 128 BLOCK_M2: tl.constexpr, # 128 @@ -762,7 +834,11 @@ def bwd_kernel_noncausal( ACTUAL_HEAD_DIM: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, DEBUG_TRITON: tl.constexpr, DEBUG_TRITON_DETAIL: tl.constexpr, ): @@ -786,7 +862,7 @@ def bwd_kernel_noncausal( seqlen_k = k_end - k_start PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - offs_k = tl.arange(0, HEAD_DIM) + offs_d = tl.arange(0, HEAD_DIM) GROUP_SIZE: tl.constexpr = HQ // HK start_n = pid * BLOCK_N1 @@ -798,15 +874,15 @@ def bwd_kernel_noncausal( # Mask for loading K and V mask_kv = offs_n[:, None] < seqlen_k if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_kv &= mask_k[None, :] - offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk - + mask_d = offs_d < ACTUAL_HEAD_DIM + mask_kv &= mask_d[None, :] + # NOTE: don't assume that the strides for k and v are the same! # K/V tensors not changed for the group - adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0) - v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0) + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) # If MQA / GQA, set the K and V head offsets appropriately. for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): # offset input and output tensor by batch and Q/K heads @@ -818,6 +894,12 @@ def bwd_kernel_noncausal( M_ptr = M + adj_delta Delta_ptr = Delta + adj_delta + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + # batch_philox_offset is the ACTUALLY dropout offset # dropout_offset is for debug purpose and will be removed later batch_philox_offset = 0 @@ -825,40 +907,53 @@ def bwd_kernel_noncausal( if ENABLE_DROPOUT: batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ hqid * stride_dropouth - dropout_offset = dropout_mask + bid * stride_dropoutb + \ + dropout_offset = Dropout_mask + bid * stride_dropoutb + \ hqid * stride_dropouth + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + # because there is no causal, we always start from the beginning start_m = 0 num_steps = tl.cdiv(seqlen_q, BLOCK_M1) dk, dv = _bwd_dkdv_inner( dk, dv, # output tensors Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qk, # strides for q - stride_dom, stride_dok, # strides for o + stride_qm, stride_qd, # strides for q + stride_dom, stride_dod, # strides for o stride_dropoutm, stride_dropoutn, # strides for dropout stride_deltam, BLOCK_M1, BLOCK_N1, # block dim HEAD_DIM, ACTUAL_HEAD_DIM, # head dim dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k start_n, start_m, num_steps, # iteration numbers - None, None, None, None, + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user MASK=False, # causal masking ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) - # Write back dV and dK. - adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn - offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk - tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + # Write back dV + adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn + offs_dv = offs_n[:, None] * stride_dvn + offs_d[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_kv) + # write back dk + adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn + offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd dk *= sm_scale - tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) # THIS PART DOES DQ start_m = pid * BLOCK_M2 @@ -867,13 +962,12 @@ def bwd_kernel_noncausal( # Mask for loading K and V mask_q = offs_m[:, None] < seqlen_q if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_q &= mask_k[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk - offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn - K += adj_kv - V += adj_kv + mask_d = offs_d < ACTUAL_HEAD_DIM + mask_q &= mask_d[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn + V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn # If MQA / GQA, set the K and V head offsets appropriately. for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): # offset input and output tensor by batch and Q/K heads @@ -883,6 +977,12 @@ def bwd_kernel_noncausal( bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam Delta_ptr = Delta + adj_delta + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + # batch_philox_offset is the ACTUALLY dropout offset # dropout_offset is for debug purpose and will be removed later batch_philox_offset = 0 @@ -892,7 +992,7 @@ def bwd_kernel_noncausal( bid * stride_dropoutb + \ hqid * stride_dropouth dropout_offset = \ - dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) @@ -900,6 +1000,14 @@ def bwd_kernel_noncausal( mask=offs_m < seqlen_q) m = m[:, None] + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + # start can only be 0 at minimum start_n = 0 end_n = seqlen_k @@ -907,31 +1015,39 @@ def bwd_kernel_noncausal( dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) dq = _bwd_dq_inner( - dq, # - q, K, V, do, m, Delta_ptr, sm_scale, # - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, # - stride_dropoutm, stride_dropoutn, # + dq, + q, K, V, do, m, Delta_ptr, sm_scale, + stride_qm, stride_qd, stride_kn, stride_kd, stride_vn, stride_vd, + stride_dropoutm, stride_dropoutn, stride_deltam, - seqlen_q, seqlen_k, # - BLOCK_M2, BLOCK_N2, # - HEAD_DIM, ACTUAL_HEAD_DIM, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - start_m, start_n, end_n, num_steps, # - None, None, None, None, - MASK=False, # + seqlen_q, seqlen_k, + BLOCK_M2, BLOCK_N2, + HEAD_DIM, ACTUAL_HEAD_DIM, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + MASK=False, ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) # Write back dQ. adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd dq *= sm_scale tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) +def is_contiguous(x, name): + if x.is_contiguous(): + return x + else: + print(f"{name} is not contiguous") + return x.contiguous() def attention_prefill_backward_triton_split_oneKernel_impl( do: torch.Tensor, @@ -955,11 +1071,55 @@ def attention_prefill_backward_triton_split_oneKernel_impl( philox_seed: Optional[int], philox_offset: Optional[int], use_exp2: bool, + # fp8 + descale_q: Optional[torch.Tensor], + descale_k: Optional[torch.Tensor], + descale_v: Optional[torch.Tensor], + descale_o: Optional[torch.Tensor], + descale_do: Optional[torch.Tensor], + descale_dq: Optional[torch.Tensor], + descale_dk: Optional[torch.Tensor], + descale_dv: Optional[torch.Tensor], ): # debug DEBUG_TRITON: bool = False DEBUG_TRITON_DETAIL: bool = False + # do = is_contiguous(do, "do") + # q = is_contiguous(q, "q") + # k = is_contiguous(k, "k") + # v = is_contiguous(v, "v") + # o = is_contiguous(o, "o") + # softmax_lse = is_contiguous(softmax_lse, "softmax_lse") + # dq = is_contiguous(dq, "dq") + # dk = is_contiguous(dk, "dk") + # dv = is_contiguous(dv, "dv") + + IS_FP8 = is_fp8(q) + if IS_FP8: + FP8_MAX = torch.finfo(q.dtype).max + # assert that the main inputs are fp8 + assert is_fp8(do) and is_fp8(q) and is_fp8(k) and is_fp8(v), f"Non fp8 type found: do.dtype={do.dtype}, q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}. All tensors must be fp8." + if is_fp8(o): + FP8_OUTPUT = True + assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o." + assert descale_dq is not None, f"descale_dq is None. In fp8, you need to pass a tensor for descale_dq along with a tensor dq." + assert descale_dk is not None, f"descale_dk is None. In fp8, you need to pass a tensor for descale_dk along with a tensor dk." + assert descale_dv is not None, f"descale_dv is None. In fp8, you need to pass a tensor for descale_dv along with a tensor dv." + else: + FP8_OUTPUT = False + + stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None + stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None + stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None + stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None + stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None + else: + FP8_MAX = None + FP8_OUTPUT = False + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = stride_descale_do_z = None + + # get strides and shape batch, nheads_q, nheads_k, head_size, max_seqlen_q_final, max_seqlen_k_final = \ get_shapes_from_layout( @@ -969,21 +1129,23 @@ def attention_prefill_backward_triton_split_oneKernel_impl( ) q_strides, k_strides, v_strides, o_strides = \ get_strides_from_layout(q, k, v, o, layout) - stride_qb, stride_qh, stride_qm, stride_qk = q_strides - stride_kb, stride_kh, stride_kn, stride_kk = k_strides - stride_vb, stride_vh, stride_vn, stride_vk = v_strides - stride_ob, stride_oh, stride_om, stride_ok = o_strides - dq_strides, dk_strides, _, do_strides = \ + stride_qb, stride_qh, stride_qm, stride_qd = q_strides + stride_kb, stride_kh, stride_kn, stride_kd = k_strides + stride_vb, stride_vh, stride_vn, stride_vd = v_strides + stride_ob, stride_oh, stride_om, stride_od = o_strides + dq_strides, dk_strides, dv_strides, do_strides = \ get_strides_from_layout(dq, dk, dv, do, layout) - stride_dqb, stride_dqh, stride_dqm, stride_dqk = dq_strides - stride_dkb, stride_dkh, stride_dkn, stride_dkk = dk_strides - stride_dob, stride_doh, stride_dom, stride_dok = do_strides + stride_dqb, stride_dqh, stride_dqm, stride_dqd = dq_strides + stride_dkb, stride_dkh, stride_dkn, stride_dkd = dk_strides + stride_dvb, stride_dvh, stride_dvn, stride_dvd = dv_strides + stride_dob, stride_doh, stride_dom, stride_dod = do_strides IS_VARLEN = layout == "thd" use_dropout = (dropout_p > 0.0) + use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) # get closest power of 2 over or equal to 32. padded_d_model = 1 << (head_size - 1).bit_length() - padded_d_model = max(padded_d_model, 16) + padded_d_model = max(padded_d_model, 32) HEAD_DIM = padded_d_model ACTUAL_HEAD_DIM = head_size @@ -998,17 +1160,20 @@ def attention_prefill_backward_triton_split_oneKernel_impl( _bwd_preprocess[pre_grid]( o, do, delta, - stride_ob, stride_oh, stride_om, stride_ok, + stride_ob, stride_oh, stride_om, stride_od, stride_deltab, stride_deltah, stride_deltam, - 0, + stride_descale_do_z, cu_seqlens_q, max_seqlen_q_final, - None, + descale_do, HEAD_DIM=HEAD_DIM, ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, IS_VARLEN=IS_VARLEN, - IS_FP8=False + IS_FP8=IS_FP8 ) + if DEBUG: + print("delta:", delta, delta.shape) + # dropout mask tensor for debugging. We dump the dropout mask created in # the kernel for testing dropout_mask = None @@ -1043,23 +1208,32 @@ def attention_prefill_backward_triton_split_oneKernel_impl( bwd_kernel_causal[grid]( q, k, v, sm_scale, do, dq, dk, dv, softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_qb, stride_qh, stride_qm, stride_qd, + stride_kb, stride_kh, stride_kn, stride_kd, + stride_vb, stride_vh, stride_vn, stride_vd, + stride_dqb, stride_dqh, stride_dqm, stride_dqd, + stride_dkb, stride_dkh, stride_dkn, stride_dkd, + stride_dvb, stride_dvh, stride_dvn, stride_dvd, stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, + stride_dob, stride_doh, stride_dom, stride_dod, stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, nheads_q, nheads_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q_final, max_seqlen_k_final, dropout_mask, dropout_p, philox_seed, philox_offset, - HEAD_DIM=HEAD_DIM, + alibi_slopes, + descale_q, descale_k, descale_v, descale_do, + HEAD_DIM=HEAD_DIM, ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=FP8_OUTPUT, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) @@ -1067,23 +1241,32 @@ def attention_prefill_backward_triton_split_oneKernel_impl( bwd_kernel_noncausal[grid]( q, k, v, sm_scale, do, dq, dk, dv, softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_qb, stride_qh, stride_qm, stride_qd, + stride_kb, stride_kh, stride_kn, stride_kd, + stride_vb, stride_vh, stride_vn, stride_vd, + stride_dqb, stride_dqh, stride_dqm, stride_dqd, + stride_dkb, stride_dkh, stride_dkn, stride_dkd, + stride_dvb, stride_dvh, stride_dvn, stride_dvd, stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, + stride_dob, stride_doh, stride_dom, stride_dod, stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, nheads_q, nheads_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q_final, max_seqlen_k_final, dropout_mask, dropout_p, philox_seed, philox_offset, + alibi_slopes, + descale_q, descale_k, descale_v, descale_do, HEAD_DIM=HEAD_DIM, ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=FP8_OUTPUT, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) diff --git a/flash_attn/flash_attn_triton_amd/bwd_ref.py b/flash_attn/flash_attn_triton_amd/bwd_ref.py index 90a98ce4fcc..639211a51f6 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/bwd_ref.py @@ -122,7 +122,7 @@ def attention_backward_core_ref_impl( print("dp:", dp, dp.shape) # calculate ds - if False: + if True: delta = torch.sum(o * do, axis=-1).unsqueeze(-1) else: delta = torch.sum(p * dp, axis=-1).unsqueeze(-1) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index dec5673e3e5..08a307e7669 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -2,7 +2,7 @@ import triton import triton.language as tl from typing import Literal, Optional, Union -from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_shapes_from_layout, get_strides_from_layout, is_cdna, is_fp8, is_rdna, create_dropout_mask +from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_arch, get_shapes_from_layout, get_strides_from_layout, is_cdna, is_fp8, is_rdna, create_dropout_mask # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) @@ -80,22 +80,24 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE + if USE_ALIBI: + # compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, + global_n_positions) + qk_scaled += alibi_block + if IS_CAUSAL: causal_boundary = start_n + offs_n_causal causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) + if bias_ptrs is not None: bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) qk_scaled += bias - if USE_ALIBI: - # compute the global position of each token within the sequence - global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - global_n_positions = start_n + tl.arange(0, BLOCK_N) - alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, - global_n_positions) - qk_scaled += alibi_block # get max scores so far m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) @@ -213,12 +215,28 @@ def get_autotune_configs(): else: raise ValueError("Unknown Device Type") else: - return [ - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + arch = get_arch() + if arch == "gfx950": + default_config = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, num_warps=4, - ), + ) + elif arch == "gfx942" and False: # Disabled due shared mem oom in CI when using triton==3.3.0 when using top of tree everything seems fine. + default_config = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + else: + default_config = triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + + return [ + default_config ], [ "IS_CAUSAL", "dropout_p", @@ -250,14 +268,29 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, IS_INFERENCE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr): + IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr, FLIP_GRID: tl.constexpr): # set params ACCUMULATOR_TYPE = tl.float32 # compute offsets - start_m = tl.program_id(0) - off_h_q = tl.program_id(1) - off_z = tl.program_id(2) + if FLIP_GRID: + #NUM_XCDS: tl.constexpr = 8 + off_z = tl.program_id(0) + off_h_q = tl.program_id(1) + start_m = tl.program_id(2) + + #start_m = (tl.cdiv(MAX_SEQLENS_Q, BLOCK_M) - 1) - start_m + + # Remap heads to the same XCD + #pids_per_xcd = HQ // NUM_XCDS + #xcd_group = off_h_q % NUM_XCDS + #pid_in_xcd = off_h_q // NUM_XCDS + #off_h_q = xcd_group * pids_per_xcd + pid_in_xcd + else: + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) @@ -308,7 +341,7 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) - o_ptrs_mask = offs_m[:, None] < seqlen_q + o_ptrs_mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) # We still need to write 0s to the result tl.store(o_ptrs, acc, mask=o_ptrs_mask) # The tensor allocated for L is based on MAX_SEQLENS_Q as that is @@ -598,7 +631,11 @@ def attention_prefill_forward_triton_impl( # kernel is padded - there is no padding in memory for any dims. padded_d_model = max(padded_d_model, 16) - grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch) + FLIP_GRID = True + if FLIP_GRID: + grid = lambda META: (batch, nheads_q, triton.cdiv(max_seqlens_q, META['BLOCK_M'])) + else: + grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch) # sd_mask is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according @@ -643,6 +680,6 @@ def attention_prefill_forward_triton_impl( MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, IS_VARLEN=is_varlen, IS_INFERENCE=is_inference, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p - > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT) + > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT, FLIP_GRID=FLIP_GRID) return softmax_lse, sd_mask if return_softmax else None diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index bb6e25b509c..a92b6f5d65d 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -13,6 +13,10 @@ from flash_attn.layers.rotary import apply_rotary_emb from typing import Literal, Optional, Union + +USE_EXP2 = True +BWD_MODE = os.environ.get('BWD_MODE', 'jingning').lower() + def fwd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -66,12 +70,19 @@ def fwd(q: torch.Tensor, if return_softmax: metadata.return_scores = True - batch, nheads_q, nheads_k, head_size, _, _ = get_shapes_from_layout(q, k, metadata.layout) + # get shape + batch, _ , nheads_q, _= q.shape if causal: metadata.need_causal(True) if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError(f"Alibi can be (nheads,) or (batch_size, nheads). Given tensor with shape {alibi_slopes.shape}") metadata.need_alibi(alibi_slopes, batch, nheads_q) if dropout_p > 0.0: @@ -103,7 +114,7 @@ def fwd(q: torch.Tensor, metadata.dropout_p, metadata.philox_seed, metadata.philox_offset, - metadata.use_exp2) + USE_EXP2) softmax_lse=softmax_lse_ref sd_mask=sd_mask_ref else: @@ -129,7 +140,7 @@ def fwd(q: torch.Tensor, metadata.philox_seed, metadata.philox_offset, metadata.return_scores, - metadata.use_exp2, + USE_EXP2, descale_q, descale_k, descale_v, @@ -147,7 +158,6 @@ def fwd(q: torch.Tensor, return out, softmax_lse, sd_mask, rng_state -BWD_MODE = os.environ.get('BWD_MODE', 'split').lower() def bwd( dout: torch.Tensor, q: torch.Tensor, @@ -212,12 +222,23 @@ def bwd( dk = torch.zeros_like(k) if dk is None else dk.zero_() dv = torch.zeros_like(v) if dv is None else dv.zero_() + # get shape + batch, _ , nheads_q, _= q.shape + if dropout_p > 0.0: assert rng_state is not None philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() else: philox_seed, philox_offset = None, None + if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") + # call implementation if USE_REF: if DEBUG: @@ -244,7 +265,7 @@ def bwd( dropout_p, philox_seed, philox_offset, - False, + USE_EXP2, ) delta = delta_ref else: @@ -272,7 +293,7 @@ def bwd( dropout_p, philox_seed, philox_offset, - False, + USE_EXP2, descale_q, descale_k, descale_v, @@ -333,7 +354,15 @@ def bwd( dropout_p, philox_seed, philox_offset, - False + USE_EXP2, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + descale_dq, + descale_dk, + descale_dv, ) delta = delta_triton else: @@ -414,13 +443,20 @@ def varlen_fwd( metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # set layout to "thd" and other metdata assert metadata.layout is not None - # get shapes - batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shapes_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) + # get shape + batch = len(cu_seqlens_q) - 1 + _, nheads_q, _= q.shape if causal: metadata.need_causal(True) if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") metadata.need_alibi(alibi_slopes, batch, nheads_q) if dropout_p > 0.0: @@ -452,7 +488,7 @@ def varlen_fwd( metadata.dropout_p, metadata.philox_seed, metadata.philox_offset, - metadata.use_exp2) + USE_EXP2) softmax_lse=softmax_lse_ref sd_mask=sd_mask_ref else: @@ -478,7 +514,7 @@ def varlen_fwd( metadata.philox_seed, metadata.philox_offset, metadata.return_scores, - metadata.use_exp2, + USE_EXP2, descale_q, descale_k, descale_v, @@ -563,12 +599,24 @@ def varlen_bwd( dk = torch.zeros_like(k) if dk is None else dk.zero_() dv = torch.zeros_like(v) if dv is None else dv.zero_() + # get shape + batch = len(cu_seqlens_q) - 1 + _, nheads_q, _= q.shape + if dropout_p > 0.0: assert rng_state is not None philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() else: philox_seed, philox_offset = None, None + if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") + # call implementation if USE_REF: if DEBUG: @@ -594,44 +642,108 @@ def varlen_bwd( dropout_p, philox_seed, philox_offset, - False, + USE_EXP2, ) delta = delta_ref else: if DEBUG: print("Using Triton implementation") - delta_triton = attention_prefill_backward_triton_split_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - "thd", - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - False, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - delta = delta_triton + if BWD_MODE == "split": + delta_triton = attention_prefill_backward_triton_split_impl( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + "thd", + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + USE_EXP2, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + descale_dq, + descale_dk, + descale_dv, + ) + delta = delta_triton + elif BWD_MODE == "fused": + delta_triton = attention_prefill_backward_triton_fused_impl( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + descale_o, + True, + ) + delta = delta_triton + elif BWD_MODE == "jingning": + delta_triton = attention_prefill_backward_triton_split_oneKernel_impl( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + "thd", + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + USE_EXP2, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + descale_dq, + descale_dk, + descale_dv, + ) + delta = delta_triton + else: + raise ValueError(f"Unknown bwd mode {BWD_MODE}") if DEBUG: print("varlen_bwd outputs") @@ -703,11 +815,19 @@ def fwd_kvcache( k_new = k v_new = v + # get shape + batch, _ , nheads_q, _= q.shape + if causal: metadata.need_causal(True) if alibi_slopes is not None: - batch, _ , nheads_q, _= q.shape + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") metadata.need_alibi(alibi_slopes, batch, nheads_q) # rotary boolean @@ -785,7 +905,7 @@ def fwd_kvcache( metadata.philox_seed, metadata.philox_offset, metadata.return_scores, - metadata.use_exp2, + USE_EXP2, None, None, None, diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index 58e2ae5fc7f..ea82de065b5 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -23,7 +23,7 @@ from .utils import DEBUG, input_helper, arch_supports_fp8 from .fwd_ref import attention_forward_pytorch_ref_impl from .fwd_prefill import attention_prefill_forward_triton_impl -from .bwd_prefill_split import attention_prefill_backward_triton_split_impl +from .bwd_prefill_onekernel import attention_prefill_backward_triton_split_oneKernel_impl from .bwd_ref import attention_backward_pytorch_ref_impl # set print options @@ -83,6 +83,7 @@ @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('use_exp2', [True, False]) # works when use_exp2 is false @pytest.mark.parametrize('DEBUG_INPUT', [False]) # NOTE: debug input can overflow when the tensors are large. Just use to figure out issues +@pytest.mark.skip() def test_op_prefill_fwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, alibi_slopes, layout, dtype, use_exp2, DEBUG_INPUT): torch.manual_seed(42) device = "cuda" @@ -96,7 +97,6 @@ def test_op_prefill_fwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dr print("MHA") # update metadata - metadata.use_exp2 = use_exp2 if causal: metadata.need_causal(True) @@ -129,7 +129,7 @@ def test_op_prefill_fwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dr metadata.philox_seed, metadata.philox_offset, metadata.return_scores, - metadata.use_exp2, + use_exp2, None, None, None, @@ -259,6 +259,7 @@ def test_op_prefill_fwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dr @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('use_exp2', [False]) # FIXME: using exp2 causes issue when used with causal @pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans on larger tensors +@pytest.mark.skip() def test_op_prefill_bwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, alibi_slopes, layout, dtype, use_exp2, DEBUG_INPUT): torch.manual_seed(20) device="cuda" @@ -333,7 +334,7 @@ def test_op_prefill_bwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dr dq_triton = torch.zeros_like(q_triton, dtype=q.dtype) # NOTE: the kernel does inplace accumlation on dq so dq has to be zeros dk_triton = torch.zeros_like(k_triton, dtype=k.dtype) if DEBUG_INPUT else torch.empty_like(k_triton, dtype=k.dtype) dv_triton = torch.zeros_like(v_triton, dtype=v.dtype) if DEBUG_INPUT else torch.empty_like(v_triton, dtype=v.dtype) - delta_triton = attention_prefill_backward_triton_split_impl( + delta_triton = attention_prefill_backward_triton_split_oneKernel_impl( do_triton, q_triton, k_triton, diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 0300e3902a1..cc4f7fa624c 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -12,6 +12,8 @@ # Gloabl Variables # ------------------------------- AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes') +if AUTOTUNE: + os.environ["TRITON_PRINT_AUTOTUNING"] = "1" DEBUG = os.environ.get('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in ('1', 'true', 'yes') USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes') PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes') @@ -48,7 +50,6 @@ class MetaData(): philox_seed: Optional[int] = None philox_offset : Optional[int]= None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing. # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW. - use_exp2: bool = False rotary_sin: Optional[torch.Tensor] = None rotary_cos: Optional[torch.Tensor] = None rotary_interleaved: bool = False diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py index b5e026803c2..6073cb1c35a 100755 --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -16,7 +16,7 @@ from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import _get_block_size_n from flash_attn.layers.rotary import apply_rotary_emb -from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_rdna +from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_hip, is_rdna MAX_HEADDIM_SM8x = 192 @@ -26,6 +26,8 @@ is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0) is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0) +skip_bfloat16 = True if is_sm75 or is_hip() else False + def attn_bias_from_alibi_slopes( slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None @@ -565,7 +567,7 @@ def get_dropout_fraction( return dropped.sum() / valid.sum() -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("deterministic", [False]) @@ -714,7 +716,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("deterministic", [True]) @@ -862,9 +864,9 @@ def test_flash_attn_varlen_qkvpacked( assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() -@pytest.mark.parametrize("kvpacked", [False]) +@pytest.mark.parametrize("kvpacked", [True, False]) # @pytest.mark.parametrize("kvpacked", [False]) -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @@ -1139,7 +1141,7 @@ def test_flash_attn_output( @pytest.mark.parametrize("kvpacked", [False]) # @pytest.mark.parametrize('kvpacked', [False]) -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize('mha_type', ["mqa"]) @@ -1459,7 +1461,7 @@ def test_flash_attn_varlen_output( assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True]) @@ -1572,7 +1574,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True]) @@ -1741,7 +1743,7 @@ def test_flash_attn_varlen_causal( assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("deterministic", [True]) @@ -1871,7 +1873,7 @@ def test_flash_attn_splitkv( assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4 -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("num_splits", [1, 0]) # @pytest.mark.parametrize("num_splits", [1]) @@ -2183,7 +2185,7 @@ def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [True]) @@ -2310,7 +2312,7 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): ).abs().max().item() + 1e-3 -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.bfloat16]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [False]) @@ -2400,7 +2402,7 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): assert not v.grad.isnan().any() -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True]) @@ -2459,7 +2461,7 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc assert torch.equal(dq, dq0) -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True])