From 36ad14632c3e23c4ae7d342abae11b4f727b9e48 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Sat, 3 May 2025 08:21:47 -0400 Subject: [PATCH 01/27] Fused Bwd (#137) * Fused with Good perf and stride fixed Fix fused bugs isolate failing case fix bug bring back test cases rm split impl in fused use exp2 is global variable now try oom fix save make fused the default limit to reproduce failure return default to split fix head size bug use exp2 back to true * new grid * BLK_SLICE_FACTOR = 1 * add tflops * new commit * test in parrallel * strides added by jusson * disable alibi * fix bugs again * default to fused * add bwd options for varlen * backend filter * default to jingning and batch 4 * best fwd config * fix TRITON_PRINT_AUTOTUNING flag bug * tune * Tuning fwd prefill * add if else * use flag * Minor mask fix * FLIP GRID * use best config for default * print when autotuning * test bfloat16 * fix k and v stride bugs * skip bfloat16 * test kvpacked * disable internal tests * pick default config based on arch * Add alibi in the new bwd kernel (#139) * enable alibi for jinging kernel enable alibi for jinging kernel match * save bad configs * fix alibi and causal bug * disable autotune by default * auto tune when benching is good * set best config * remove env var * Update amd_tests.yml * upgrad to triton==3.3.0 * increase shm * use 64 x 64 for now * save * handle 1d alibi * Add fp8 to fused kernel (#140) * fp8 stuff find test case compute delta fp8 basic fp8 config passing non causal path works * isolate bad case * fix fp8 bug * didnot fix fp8 bug * back to failing test * fp8 tests passing * skip * skip ref tests --------- Co-authored-by: Aliasger Zaidy --- .github/workflows/amd_nightly.yml | 105 ++++ .github/workflows/amd_tests.yml | 63 +++ README.md | 6 + flash_attn/flash_attn_triton_amd/Dockerfile | 2 +- flash_attn/flash_attn_triton_amd/README.md | 4 +- flash_attn/flash_attn_triton_amd/bench.py | 155 ++++-- .../bwd_prefill_fused.py | 82 +-- .../bwd_prefill_onekernel.py | 493 ++++++++++++------ flash_attn/flash_attn_triton_amd/bwd_ref.py | 2 +- .../flash_attn_triton_amd/fwd_prefill.py | 75 ++- .../flash_attn_triton_amd/interface_fa.py | 218 ++++++-- flash_attn/flash_attn_triton_amd/test.py | 9 +- flash_attn/flash_attn_triton_amd/utils.py | 3 +- tests/test_flash_attn_triton_amd.py | 30 +- 14 files changed, 934 insertions(+), 313 deletions(-) create mode 100644 .github/workflows/amd_nightly.yml create mode 100644 .github/workflows/amd_tests.yml diff --git a/.github/workflows/amd_nightly.yml b/.github/workflows/amd_nightly.yml new file mode 100644 index 00000000000..3131496ac49 --- /dev/null +++ b/.github/workflows/amd_nightly.yml @@ -0,0 +1,105 @@ +name: AMD Nightly Kernel Tests + +on: + workflow_dispatch: + push: + branches: [main_perf] + schedule: + - cron: '0 0 * * *' # runs nightly at midnight UTC + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + Nightly-CDNA-AMD: + runs-on: ${{ matrix.runner }} + strategy: + matrix: + runner: [linux-mi300-gpu-1] + fail-fast: false # disables failing the entire job when one matrix entry fails + timeout-minutes: 720 # self hosted runners can run jobs for longer than the default of 360 minutes + container: + image: rocm/pytorch:latest + options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --shm-size 16G --group-add video --user root + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Show Device Info + run: | + rocminfo | grep gfx + + - name: Uninstall Triton + run: | + pip uninstall -y triton + rm -rf ~/.triton + rm -rf ./triton/python/build + + - name: Install Triton + run: | + pip install triton==3.3.0 + + - name: Show Triton version + run: | + pip show triton + + - name: Build + run: | + FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install + + - name: Install dependencies for bench and misc + run: | + pip install matplotlib pandas tabulate + + - name: AMD Internal Tests + run: | + FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest flash_attn/flash_attn_triton_amd/test.py + + - name: Flash Attention Tests + run: | + FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest -n 8 tests/test_flash_attn_triton_amd.py + + - name: AMD Bench + run: | + python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_func flash_attn_varlen_func flash_attn_with_kvcache + + Nightly-RDNA-AMD: + runs-on: ${{ matrix.runner }} + strategy: + matrix: + runner: [gfx1100] + fail-fast: false # disables failing the entire job when one matrix entry fails + timeout-minutes: 720 # self hosted runners can run jobs for longer than the default of 360 minutes + container: + image: rocm/pytorch:latest + options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Show Device Info + run: | + rocminfo | grep gfx + + - name: Uninstall Triton + run: | + pip uninstall -y triton + rm -rf ~/.triton + rm -rf ./triton/python/build + + - name: Install Triton + run: | + pip install triton==3.3.0 + + - name: Show Triton version + run: | + pip show triton + + - name: Build + run: | + FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install + + - name: Flash Attention Tests + run: | + FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest tests/test_flash_attn_triton_amd.py::test_flash_attn_output diff --git a/.github/workflows/amd_tests.yml b/.github/workflows/amd_tests.yml new file mode 100644 index 00000000000..2f49567f960 --- /dev/null +++ b/.github/workflows/amd_tests.yml @@ -0,0 +1,63 @@ +name: AMD Perf Kernel Tests + +on: + workflow_dispatch: + pull_request: + branches: [main_perf] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + Integration-Tests-AMD: + runs-on: ${{ matrix.runner }} + strategy: + matrix: + runner: [linux-mi300-gpu-1] + fail-fast: false # disables failing the entire job when one matrix entry fails + timeout-minutes: 720 # self hosted runners can run jobs for longer than the default of 360 minutes + container: + image: rocm/pytorch:latest + options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --shm-size 16G --group-add video --user root + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Show Device Info + run: | + rocminfo | grep gfx + + - name: Uninstall Triton + run: | + pip uninstall -y triton + rm -rf ~/.triton + rm -rf ./triton/python/build + + - name: Install Triton + run: | + pip install triton==3.3.0 + + - name: Show Triton version + run: | + pip show triton + + - name: Build + run: | + FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install + + - name: Install dependencies for bench and misc + run: | + pip install matplotlib pandas tabulate + + - name: AMD Internal Tests + run: | + FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest flash_attn/flash_attn_triton_amd/test.py + + - name: Flash Attention Tests + run: | + FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest -n 8 tests/test_flash_attn_triton_amd.py + + - name: AMD Bench + run: | + python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_func flash_attn_varlen_func flash_attn_with_kvcache diff --git a/README.md b/README.md index b7e02867095..65a4154da4a 100755 --- a/README.md +++ b/README.md @@ -153,6 +153,9 @@ To get started with the triton backend for AMD, follow the steps below. First install the torch for ROCm from https://pytorch.org/get-started/locally/ if it is not installed. The torch and triton will be installed. +``` +pip install triton==3.3.0 +``` Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. ``` @@ -177,6 +180,9 @@ FROM rocm/pytorch:latest WORKDIR /workspace +# install triton +RUN pip install triton==3.3.0 + # install flash attention ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" diff --git a/flash_attn/flash_attn_triton_amd/Dockerfile b/flash_attn/flash_attn_triton_amd/Dockerfile index 29a2c0c43ec..8df939a0886 100644 --- a/flash_attn/flash_attn_triton_amd/Dockerfile +++ b/flash_attn/flash_attn_triton_amd/Dockerfile @@ -3,7 +3,7 @@ FROM rocm/pytorch:latest WORKDIR /workspace # install triton -RUN pip install triton==3.2.0 +RUN pip install triton==3.3.0 # install flash attention ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" diff --git a/flash_attn/flash_attn_triton_amd/README.md b/flash_attn/flash_attn_triton_amd/README.md index 2d8fd8e70f3..f3a5db67fc5 100644 --- a/flash_attn/flash_attn_triton_amd/README.md +++ b/flash_attn/flash_attn_triton_amd/README.md @@ -28,7 +28,7 @@ To get started with the triton backend for AMD, follow the steps below. First install the recommended Triton version ``` -pip install triton==3.2.0 +pip install triton==3.3.0 ``` Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. @@ -56,7 +56,7 @@ FROM rocm/pytorch:latest WORKDIR /workspace # install triton -RUN pip install triton==3.2.0 +RUN pip install triton==3.3.0 # install flash attention ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" diff --git a/flash_attn/flash_attn_triton_amd/bench.py b/flash_attn/flash_attn_triton_amd/bench.py index 05e64c349be..f2b2e7d11d6 100755 --- a/flash_attn/flash_attn_triton_amd/bench.py +++ b/flash_attn/flash_attn_triton_amd/bench.py @@ -80,7 +80,7 @@ class EnvVariableConfig: backend: Optional[Literal["triton", "ck"]] = None ENV_VARIABLE_CONFIGS : List[EnvVariableConfig] = [ - EnvVariableConfig(key="BWD_MODE", values=["split", "fused", "jingning"], backend="triton"), + # EnvVariableConfig(key="BWD_MODE", values=["split", "fused", "jingning"], backend="triton"), ] class FunctionConfig: @@ -871,8 +871,8 @@ def load_flash_attn_module(backend: Literal["triton", "ck"], env_configs: Dict = # set environment variable for the desired backend if backend == "triton": os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE" - os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "0" os.environ["FLASH_ATTENTION_TRITON_AMD_DEBUG"] = "0" + os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "1" elif backend == "ck": os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "FALSE" else: @@ -1016,15 +1016,30 @@ def get_input_config_set(config_type): # batch, hq, hk, sq, sk, d_head, causal, dropout input_configs = [ # LLaMA 3 8B - (1, 32, 8, 8192, 8192, 128, True, 0.0), + (4, 32, 8, 8192, 8192, 128, True, 0.0), # LLaMA 3 70B - (1, 64, 8, 8192, 8192, 128, True, 0.0), + (4, 64, 8, 8192, 8192, 128, True, 0.0), ] else: raise ValueError(f"Unknown input config: {config_type}") return input_configs +def filter_backends(requested_backends, supported_backends, fn_name): + if requested_backends: + selected = [] + for be in requested_backends: + if be in supported_backends: + selected.append(be) + else: + warning( + f"backend '{be}' requested but not supported by " + f"function '{fn_name}'. skipping this back-end." + ) + return selected + else: + return supported_backends + def process_args(): """ @@ -1052,6 +1067,14 @@ def process_args(): default=None, help=f"Benchmarking mode(s) to run. If omitted, runs all supported modes for each function.", ) + parser.add_argument( + "--backend", + type=str, + nargs='*', + choices=["triton", "ck"], + default=None, + help="Back-end(s) to run (triton, ck). Omit to run every back-end that is both available and supported by the function.", + ) # config parser.add_argument("-b", type=int, default=None, help="Batch size") parser.add_argument("-hq", type=int, default=None, help="Q Number of heads") @@ -1067,7 +1090,8 @@ def process_args(): # parse function args benchmark_fns = args.benchmark_fn - requested_modes = args.mode + requested_modes = args.mode + requested_backends = args.backend # fenerate function configurations and input configurations separately all_function_configs = [] @@ -1101,9 +1125,17 @@ def process_args(): if not modes_to_run: warning(f"No valid modes to run for function '{fn_name}' based on request and function support. Skipping this function.") continue + + # filter by backend + backends_to_run = filter_backends(requested_backends, + supported_backends, + fn_name) + if not backends_to_run: + warning(f"no valid back-ends left for '{fn_name}'. skipping.") + continue # create a function config for each backend and dtype combination - for backend in supported_backends: + for backend in backends_to_run: for dtype in supported_dtypes: for mode in modes_to_run: for env_config in supported_env_configs[backend]: @@ -1124,6 +1156,52 @@ def check_environment_variables(): if key in os.environ: raise ValueError(f"Running with {key} environment variable is not recommended for the benching script. Use --help to see how to use the benching script.") +def compute_flops(batch, hq, hk, sq, sk, d_head, causal): + # 2 FLOPs per multiply‑add + if causal: + valid_pairs = ((sk * (sk + 1)) // 2 if sq > sk else + sq * sk - (sq * (sq - 1)) // 2) + else: + valid_pairs = sq * sk + return 2 * batch * hq * valid_pairs * d_head + +# see ref, https://github.com/ROCm/aiter/blob/jukorhon/mha-bwd/op_benchmarks/triton/bench_mha.py +def _flops_single_row(row: pd.Series, mode: str) -> float: + b, hq, d_head = int(row["BATCH"]), int(row["HQ"]), int(row["D_HEAD"]) + sq, sk = int(row["N_CTX_Q"]), int(row["N_CTX_K"]) + causal = bool(row["CAUSAL"]) + + # -------- number of (query, key) products per head ---------------- + if not causal: + valid_pairs = sq * sk + else: # triangular mask + if sq > sk: + valid_pairs = sk * (sk + 1) // 2 + (sq - sk) * sk + else: # sq <= sk + valid_pairs = sq * (sq + 1) // 2 + + # one matmul FLOPs (mul + add) = 2 · m · n · k + flops_per_matmul = 2.0 * b * hq * valid_pairs * d_head + total_flops = 2.0 * flops_per_matmul # 2 matmuls in forward + + if mode == "fwd": + pass + elif mode == "bwd": + total_flops *= 2.5 # 2·bwd + 0.5·recompute + elif mode == "full": + total_flops *= 3.5 # fwd + bwd + else: + raise ValueError(f"unknown mode {mode}") + + return total_flops + +def add_tflops_columns(df: pd.DataFrame, func_cfg: FunctionConfig) -> pd.DataFrame: + ms_col = func_cfg.column_name() + tf_col = ms_col.replace("_ms", "_tflops") + flops = df.apply(_flops_single_row, axis=1, mode=func_cfg.mode) + df[tf_col] = flops / df[ms_col] * 1e-9 + return df + def main(): """ Main function to run benchmarks. @@ -1137,27 +1215,30 @@ def main(): # process args to get function configs and input configs function_configs, all_input_configs = process_args() - # Check if we have multiple function configurations - has_multiple_func_configs = len(function_configs) > 1 - combined_df = None - # run benchmarks for each function configuration + combined_ms_df = None + combined_tf_df = None + input_cols = ["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K", "D_HEAD", "CAUSAL", "DROPOUT"] for func_config in function_configs: # run benchmark with the input configs for this function config input_configs = all_input_configs[func_config] df = run_benchmark(func_config, input_configs) + df = add_tflops_columns(df, func_config) - # Define the columns that represent input configurations - input_config_cols = ["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K", "D_HEAD", "CAUSAL", "DROPOUT"] - - # merge into one final dataframe - if combined_df is None: - combined_df = df + # add to combined table + ms_cols = [c for c in df.columns if c.endswith('_ms')] + tf_cols = [c for c in df.columns if c.endswith('_tflops')] + + ms_df = df[input_cols + ms_cols] + tf_df = df[input_cols + tf_cols] + + if combined_ms_df is None: + combined_ms_df = ms_df + combined_tf_df = tf_df else: - # Ensure we're joining on input configuration columns - combined_df = combined_df.merge(df, on=input_config_cols, how="outer") + combined_ms_df = combined_ms_df.merge(ms_df, on=input_cols, how="outer") + combined_tf_df = combined_tf_df.merge(tf_df, on=input_cols, how="outer") - # print new line to seperate the combined data information from the benchmark specific information print() @@ -1166,6 +1247,7 @@ def main(): print(f"Total time for all benchmarks: {total_elapsed_time:.2f} seconds") # save combined data and make comparisons if we have multiple function configs + has_multiple_func_configs = False # len(function_configs) > 1 if has_multiple_func_configs: if len(function_configs) == 2: func1 = function_configs[0] @@ -1194,30 +1276,25 @@ def main(): ratio_col = f"ck_to_triton_ratio" # Calculate ratio: ck_time / triton_time (values > 1 mean triton is faster) - combined_df[ratio_col] = combined_df[ck_col] / combined_df[triton_col] + combined_ms_df[ratio_col] = combined_ms_df[ck_col] / combined_ms_df[triton_col] # print explanation print(f"Comparison Results (triton vs ck):") print(f"Ratio values: values > 1 mean triton is faster (by that factor), values < 1 mean ck is faster") - elif False: - # For other comparisons, use the standard approach - ratio_col = f"{func1}_to_{func2}_ratio" - - # Calculate the ratio - combined_df[ratio_col] = combined_df[col2] / combined_df[col1] - - # print explanation - print(f"Comparison Results ({func1} vs {func2}):") - print(f"Ratio values: values > 1 mean {func1} is faster than {func2} (by that factor), values < 1 mean slower") - - print(f"Combined data:") - print(combined_df) - - # save csv & markdown - combined_filename = f"benchmark_combined" - combined_df.to_csv(f"{combined_filename}.csv", index=False) - with open(f"{combined_filename}.md", 'w') as f: - f.write(combined_df.to_markdown(index=False, floatfmt=".2f")) + + if combined_ms_df is not None: + print("\nCombined wall‑time (ms) table:") + print(combined_ms_df) + combined_ms_df.to_csv("benchmark_ms.csv", index=False) + with open("benchmark_ms.md", 'w') as f: + f.write(combined_ms_df.to_markdown(index=False, floatfmt=".2f")) + + if combined_tf_df is not None: + print("\nCombined throughput (TFLOPs) table:") + print(combined_tf_df) + combined_tf_df.to_csv("benchmark_tflops.csv", index=False) + with open("benchmark_tflops.md", 'w') as f: + f.write(combined_tf_df.to_markdown(index=False, floatfmt=".2f")) if __name__ == "__main__": main() \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py index 3c018be4fa0..af3f8790026 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py @@ -354,9 +354,9 @@ def _attn_fwd(q_ptr: torch.Tensor, VARLEN: tl.constexpr, ): #calculate offsets - start_m = tl.program_id(0) #seqlen_q - off_q_head = tl.program_id(1) #num_q_heads - off_z = tl.program_id(2) #batch + off_z = tl.program_id(0) #batch + off_q_head = tl.program_id(1) #num_q_heads + start_m = tl.program_id(2) #seqlen_q offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) @@ -730,14 +730,14 @@ def _flash_attn_forward( # Tuned for MI300x config = { 'BLOCK_M': 128, - 'BLOCK_N': 32, # BLOCK_N: 64 spills for _attn_fwd + 'BLOCK_N': 64, 'waves_per_eu': 2, 'num_warps': 4, 'num_ctas': 1, 'num_stages': 1, } - grid = lambda META:(triton.cdiv(seqlen_q, META['BLOCK_M']), num_q_heads, batch) + grid = lambda META:(batch, num_q_heads, triton.cdiv(seqlen_q, META['BLOCK_M'])) _attn_fwd[grid](q, k, v, @@ -1105,6 +1105,7 @@ def _bwd_dkdvdq_inner( dk, dv, Q, k, v, DO, DQ, M, D, sm_scale, stride_q_m, stride_q_k, + stride_dq_m, stride_dq_k, stride_do_m, stride_do_k, stride_dropout_m, stride_dropout_n, stride_deltam, @@ -1132,7 +1133,7 @@ def _bwd_dkdvdq_inner( mask_n = offs_n < seqlen_k qT_ptrs_start = Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k #[BLOCK_D_MODEL_POW2, BLOCK_M] - dq_ptrs_start = DQ + offs_m[:, None] * stride_q_m + offs_k[None,:] * stride_q_k #[BLOCK_M, BLOCK_D_MODEL_POW2] + dq_ptrs_start = DQ + offs_m[:, None] * stride_dq_m + offs_k[None,:] * stride_dq_k #[BLOCK_M, BLOCK_D_MODEL_POW2] do_ptrs_start = DO + offs_m[:, None] * stride_do_m + offs_k[None,: ] * stride_do_k curr_m = start_m @@ -1170,7 +1171,7 @@ def _bwd_dkdvdq_inner( curr_m = start_m + blk_idx * step_m qT_ptrs = qT_ptrs_start + blk_idx * step_m * stride_q_m - dq_ptrs = dq_ptrs_start + blk_idx * step_m * stride_q_m + dq_ptrs = dq_ptrs_start + blk_idx * step_m * stride_dq_m do_ptrs = do_ptrs_start + blk_idx * step_m * stride_do_m offs_m = curr_m + tl.arange(0, BLOCK_M) @@ -1279,6 +1280,7 @@ def _bwd_kernel_dkdvdq_causal( stride_k_b, stride_k_h, stride_k_n, stride_k_k, stride_v_b, stride_v_h, stride_v_n, stride_v_k, stride_dk_b, stride_dk_h, stride_dk_n, stride_dk_k, + stride_dq_b, stride_dq_h, stride_dq_m, stride_dq_k, stride_delta_b, stride_delta_h, stride_delta_m, stride_do_b, stride_do_h, stride_do_m, stride_do_k, stride_dropout_b, stride_dropout_h, stride_dropout_m, stride_dropout_n, @@ -1387,9 +1389,10 @@ def _bwd_kernel_dkdvdq_causal( # offset input and output tensor by batch and Q/K heads adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m + adj_dq = batch_idx * stride_dq_b + head_q_idx * stride_dq_h + q_start * stride_dq_m q_ptr_adj = q_ptr + adj_q - dq_ptr_adj = dq_ptr + adj_q + dq_ptr_adj = dq_ptr + adj_dq adj_do = batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m do_ptr_adj = do_ptr + adj_do @@ -1425,12 +1428,13 @@ def _bwd_kernel_dkdvdq_causal( else: descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - # if start_m is negative, the current N-tile has no block on the + # if unaligned start_m is negative, the current N-tile has no block on the # diagonal of causal mask, so everything have no causal mask dk, dv = _bwd_dkdvdq_inner( dk, dv, # output tensors q_ptr_adj, k, v, do_ptr_adj, dq_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors stride_q_m, stride_q_k, # strides for q + stride_dq_m, stride_dq_k, # strides for q stride_do_m, stride_do_k, # strides for o stride_dropout_m, stride_dropout_n, # strides for dropout stride_delta_m, @@ -1446,14 +1450,19 @@ def _bwd_kernel_dkdvdq_causal( FP8_MAX=FP8_MAX, workgroup_id=seq_k_blk_idx, ) + + start_m += num_steps * MASK_BLOCK_M num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) end_m = start_m + num_steps * BLOCK_M + + dk, dv = _bwd_dkdvdq_inner( dk, dv, # output tensors q_ptr_adj, k, v, do_ptr_adj, dq_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors stride_q_m, stride_q_k, # strides for q + stride_dq_m, stride_dq_k, # strides for dq stride_do_m, stride_do_k, # strides for o stride_dropout_m, stride_dropout_n, # strides for dropout stride_delta_m, @@ -1865,6 +1874,7 @@ def _bwd_kernel_dkdvdq_noncausal( stride_kb, stride_kh, stride_kn, stride_kk, stride_vb, stride_vh, stride_vn, stride_vk, stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, stride_deltab, stride_deltah, stride_deltam, stride_dob, stride_doh, stride_dom, stride_dok, stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, @@ -1939,9 +1949,10 @@ def _bwd_kernel_dkdvdq_noncausal( for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): adj_q = (bid * stride_qb + hqid * stride_qh + q_start * stride_qm) - + adj_dq = (bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm) + Q_ptr = Q + adj_q - DQ_ptr = DQ + adj_q + DQ_ptr = DQ + adj_dq adj_do = (bid * stride_dob + hqid * stride_doh + q_start * stride_dom) DO_ptr = DO + adj_do @@ -1973,6 +1984,7 @@ def _bwd_kernel_dkdvdq_noncausal( dk, dv, Q_ptr, k, v, DO_ptr, DQ_ptr, M_ptr, Delta_ptr, sm_scale, stride_qm, stride_qk, + stride_dqm, stride_dqk, stride_dom, stride_dok, stride_dropoutm, stride_dropoutn, stride_deltam, @@ -2372,7 +2384,7 @@ def _flash_attn_backward( if fused: # fuses dk, dv, dq computations into one kernel by computing the dq using atomic adds between workgroups - BLOCK_N = 128 + BLOCK_N = 128 if BLOCK_D_MODEL_POW2 < 160 else 64 # larger head sizes lead to oom config = { "BLOCK_M": 32, "BLOCK_N": BLOCK_N, @@ -2393,6 +2405,7 @@ def _flash_attn_backward( *k_strides, *v_strides, *dk_strides, + *dq_strides, *delta_strides, *do_strides, *dropout_strides, @@ -2422,6 +2435,7 @@ def _flash_attn_backward( *k_strides, *v_strides, *dk_strides, + *dq_strides, *delta_strides, *do_strides, *dropout_strides, @@ -2776,6 +2790,7 @@ def forward( return_lse, return_softmax, is_grad_enabled, + fused_backward, ): is_grad = is_grad_enabled and any( x.requires_grad for x in [q,k,v] @@ -2812,7 +2827,7 @@ def forward( cu_seqlens_k=None, descale_q=descale_q, descale_k=descale_k, - descale_v=descale_v + descale_v=descale_v, ) if is_grad: @@ -2824,6 +2839,7 @@ def forward( ctx.causal = causal ctx.window_size = window_size ctx.alibi_slopes = alibi_slopes + ctx.fused_backward = fused_backward out = out_padded[..., :head_size_og] result = [out] @@ -2869,11 +2885,12 @@ def backward(ctx, do, *args): descale_k=descale_k, descale_v=descale_v, descale_do=descale_do, + fused=ctx.fused_backward, ) #dq = dq[..., : q_fp8.shape[-1]] # We could have padded the head dimension #dk = dk[..., : k_fp8.shape[-1]] #dv = dv[..., : v_fp8.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None def flash_attn_fp8_func( q, @@ -2886,7 +2903,8 @@ def flash_attn_fp8_func( alibi_slopes=None, deterministic=False, return_lse=False, - return_attn_probs=False + return_attn_probs=False, + fused_backward=False, ): return FlashAttnFP8Func.apply( q, @@ -2900,7 +2918,8 @@ def flash_attn_fp8_func( deterministic, return_lse, return_attn_probs, - torch.is_grad_enabled() + torch.is_grad_enabled(), + fused_backward, ) class FlashAttnVarlenFunc(torch.autograd.Function): @@ -3127,6 +3146,7 @@ def forward( return_softmax, block_table, is_grad_enabled, + fused_backward, ): is_grad = is_grad_enabled and any( x.requires_grad for x in [q, k, v] @@ -3163,7 +3183,8 @@ def forward( cu_seqlens_k=cu_seqlens_k, descale_q=descale_q, descale_k=descale_k, - descale_v=descale_v + descale_v=descale_v, + fused_backward=fused_backward, ) if is_grad: ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, descale_q, descale_k, descale_v) @@ -3176,6 +3197,7 @@ def forward( ctx.causal = causal ctx.window_size = window_size ctx.alibi_slopes = alibi_slopes + ctx.fused_backward = fused_backward out = out_padded[..., :head_size_og] result = [out] if return_lse: @@ -3187,15 +3209,15 @@ def forward( @staticmethod def backward(ctx, do, *args): - q_fp8, k_fp8, v_fp8, out, softmax_lse, cu_seqlens_q, cu_seqlens_q, descale_q, descale_k, descale_v = ctx.saved_tensors - dq, dk, dv = torch.zeros_like(q, dtype=torch.float32), torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32) + q_fp8, k_fp8, v_fp8, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, descale_q, descale_k, descale_v = ctx.saved_tensors + dq, dk, dv = torch.zeros_like(q_fp8, dtype=torch.float32), torch.zeros_like(k_fp8, dtype=torch.float32), torch.zeros_like(v_fp8, dtype=torch.float32) head_size_v_og = do.size(3) do_padded = do if head_size_v_og % 8 != 0: do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) fp8_dtype = torch.float8_e4m3fnuz - do_padded_fp8, descale_do = cast_varlen_to_fp8(dout_padded, fp8_dtype, "thd", cu_seqlens_q) + do_padded_fp8, descale_do = cast_varlen_to_fp8(do_padded, fp8_dtype, "thd", cu_seqlens_q) _flash_attn_backward( do_padded_fp8, @@ -3212,8 +3234,8 @@ def backward(ctx, do, *args): ctx.causal, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_k=ctx.max_seqlen_k, dropout_p=ctx.dropout_p, philox_seed=ctx.philox_seed, philox_offset=ctx.philox_offset, @@ -3222,10 +3244,10 @@ def backward(ctx, do, *args): descale_v=descale_v, descale_do=descale_do ) - dq = dq[..., : q.shape[-1]] # We could have padded the head dimension - dk = dk[..., : k.shape[-1]] - dv = dv[..., : v.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None + dq = dq[..., : q_fp8.shape[-1]] # We could have padded the head dimension + dk = dk[..., : k_fp8.shape[-1]] + dv = dv[..., : v_fp8.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None def flash_attn_varlen_fp8_func( q, @@ -3243,7 +3265,8 @@ def flash_attn_varlen_fp8_func( deterministic=False, return_lse=False, return_attn_probs=False, - block_table=None + block_table=None, + fused_backward=False, ): return FlashAttnVarlenFP8Func.apply( q, @@ -3262,5 +3285,6 @@ def flash_attn_varlen_fp8_func( return_lse, return_attn_probs, block_table, - torch.is_grad_enabled() - ) \ No newline at end of file + torch.is_grad_enabled(), + fused_backward, + ) diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py index 3f650d288db..9f8a1ab46a2 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py @@ -2,8 +2,8 @@ import triton # type: ignore import triton.language as tl # type: ignore from typing import Literal, Optional -from .utils import AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, get_shapes_from_layout, compute_fp8_scaling_factors, \ - get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_cdna, is_rdna +from .utils import DEBUG, AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, get_shapes_from_layout, compute_fp8_scaling_factors, \ + get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_cdna, is_fp8, is_rdna # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) @@ -108,7 +108,7 @@ def get_autotune_configs(): def _bwd_preprocess( O, DO, # noqa: E741 Delta, - stride_ob, stride_oh, stride_om, stride_ok, + stride_ob, stride_oh, stride_om, stride_od, stride_deltab, stride_deltah, stride_deltam, stride_descale_do_z, cu_seqlens_q, max_seqlen_q, @@ -135,7 +135,7 @@ def _bwd_preprocess( # Compute offsets offs_m = pid_m * PRE_BLOCK + tl.arange(0, PRE_BLOCK) - offs_k = tl.arange(0, HEAD_DIM) + offs_d = tl.arange(0, HEAD_DIM) # Offset O/DO by batch, head and q_start O += bid * stride_ob + hid * stride_oh + q_start * stride_om # noqa: E741 DO += bid * stride_ob + hid * stride_oh + q_start * stride_om @@ -144,9 +144,9 @@ def _bwd_preprocess( mask_md = mask_m[:, None] PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) if PADDED_HEAD: - mask_md &= offs_k[None, :] < ACTUAL_HEAD_DIM + mask_md &= offs_d[None, :] < ACTUAL_HEAD_DIM # compute pointers - offs_do = offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok + offs_do = offs_m[:, None] * stride_om + offs_d[None, :] * stride_od out_ptrs = O + offs_do do_ptrs = DO + offs_do # load @@ -171,19 +171,21 @@ def _bwd_dkdv_inner( Q, k, v, DO, M, D, sm_scale, # input tensor stride_qm, stride_qk, stride_dom, stride_dok, - stride_dropoutm, stride_dropoutn, # + stride_dropoutm, stride_dropoutn, stride_deltam, BLOCK_M: tl.constexpr, # 16 BLOCK_N: tl.constexpr, # 128 HEAD_DIM: tl.constexpr, # ACTUAL_HEAD_DIM: tl.constexpr, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k # Filled in by the wrapper. start_n, start_m, num_steps, # iteration numbers descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal ENABLE_DROPOUT: tl.constexpr, # activate dropout + USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, # activate exp2 IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, @@ -246,16 +248,23 @@ def _bwd_dkdv_inner( qkT = (tl.dot(k, qT) * descale_q * descale_k) else: qkT = tl.dot(k, qT) + qkT_scaled = qkT * sm_scale + + if USE_ALIBI: + relative_pos_block = offs_n[:, None] + seqlen_q - seqlen_k - offs_m[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + qkT_scaled += alibi_block + if DEBUG_TRITON_DETAIL: if start_n == 256: print(f"qT: {qT.shape}\n", qT) print(f"k: {k.shape}\n", k) - print(f"qkT scaled: {qkT.shape}\n", qkT * sm_scale) + print(f"qkT scaled: {qkT.shape}\n", qkT_scaled) # TODO: remove the scaling of m later when we removed re-scaling in fwd if USE_EXP2: - pT = tl.math.exp2(qkT * sm_scale * RCP_LN2 - m[None, :] * RCP_LN2) + pT = tl.math.exp2(qkT_scaled * RCP_LN2 - m[None, :] * RCP_LN2) else: - pT = tl.math.exp(qkT * sm_scale - m[None, :]) + pT = tl.math.exp(qkT_scaled - m[None, :]) # Autoregressive masking. if MASK: @@ -323,12 +332,14 @@ def _bwd_dq_inner( BLOCK_N2: tl.constexpr, # HEAD_DIM: tl.constexpr, ACTUAL_HEAD_DIM: tl.constexpr, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, # Filled in by the wrapper. start_m, start_n, end_n, num_steps, # descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user MASK: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, + USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, @@ -392,11 +403,18 @@ def _bwd_dq_inner( qk = (tl.dot(q, kT) * descale_q * descale_k) else: qk = tl.dot(q, kT) - if DEBUG_TRITON_DETAIL: print(f"qk scaled: {qk.shape}\n", qk * sm_scale) # noqa: E701 + qk_scaled = qk * sm_scale + + if USE_ALIBI: + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + qk_scaled += alibi_block + + if DEBUG_TRITON_DETAIL: print(f"qk scaled: {qk.shape}\n", qk_scaled) # noqa: E701 if USE_EXP2: - p = tl.math.exp2(qk * sm_scale * RCP_LN2 - m * RCP_LN2) + p = tl.math.exp2(qk_scaled * RCP_LN2 - m * RCP_LN2) else: - p = tl.math.exp(qk * sm_scale - m) + p = tl.math.exp(qk_scaled - m) # Autoregressive masking. if MASK: @@ -434,18 +452,23 @@ def _bwd_dq_inner( def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) Q, K, V, sm_scale, DO, DQ, DK, DV, M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_qb, stride_qh, stride_qm, stride_qd, + stride_kb, stride_kh, stride_kn, stride_kd, + stride_vb, stride_vh, stride_vn, stride_vd, + stride_dqb, stride_dqh, stride_dqm, stride_dqd, + stride_dkb, stride_dkh, stride_dkn, stride_dkd, + stride_dvb, stride_dvh, stride_dvn, stride_dvd, stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, + stride_dob, stride_doh, stride_dom, stride_dod, stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, HQ, HK, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset_base, + Dropout_mask, dropout_p, philox_seed, philox_offset_base, + Alibi_slopes, + Descale_q, Descale_k, Descale_v, Descale_do, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, BLOCK_M2: tl.constexpr, @@ -455,7 +478,11 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead ACTUAL_HEAD_DIM: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, DEBUG_TRITON: tl.constexpr, DEBUG_TRITON_DETAIL: tl.constexpr, ): @@ -481,7 +508,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead delta_qk = seqlen_q - seqlen_k if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") # noqa: E701 PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - offs_k = tl.arange(0, HEAD_DIM) + offs_d = tl.arange(0, HEAD_DIM) GROUP_SIZE: tl.constexpr = HQ // HK # align the delta_qk @@ -511,15 +538,15 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead # Mask for loading K and V mask_kv = offs_n[:, None] < seqlen_k if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_kv &= mask_k[None, :] - offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk + mask_d = offs_d < ACTUAL_HEAD_DIM + mask_kv &= mask_d[None, :] # K/V tensors not changed for the group - adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0) - v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0) + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) # If MQA / GQA, set the K and V head offsets appropriately. # hqid = hkid for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): @@ -546,6 +573,12 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead M_ptr = M + adj_delta Delta_ptr = Delta + adj_delta + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + # batch_philox_offset is the ACTUALLY dropout offset # dropout_offset is for debug purpose and will be removed later batch_philox_offset = 0 @@ -553,9 +586,17 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead if ENABLE_DROPOUT: batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ hqid * stride_dropouth - dropout_offset = dropout_mask + bid * stride_dropoutb + \ + dropout_offset = Dropout_mask + bid * stride_dropoutb + \ hqid * stride_dropouth + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR # bound the masked operation to q len so it does not have to wast cycles len_m = min(len_m, seqlen_q) @@ -570,21 +611,23 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead dk, dv = _bwd_dkdv_inner( dk, dv, # output tensors Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qk, # strides for q - stride_dom, stride_dok, # strides for o + stride_qm, stride_qd, # strides for q + stride_dom, stride_dod, # strides for o stride_dropoutm, stride_dropoutn, # strides for dropout stride_deltam, MASK_BLOCK_M1, BLOCK_N1, # block dim HEAD_DIM, ACTUAL_HEAD_DIM, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k start_n, start_m, num_steps, # iteration numbers - None, None, None, None, + descale_q, descale_k, descale_v, descale_do, MASK=True, # causal masking ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) @@ -598,31 +641,36 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead dk, dv = _bwd_dkdv_inner( dk, dv, # output tensors Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qk, # strides for q - stride_dom, stride_dok, # strides for o + stride_qm, stride_qd, # strides for q + stride_dom, stride_dod, # strides for o stride_dropoutm, stride_dropoutn, # strides for dropout stride_deltam, BLOCK_M1, BLOCK_N1, # block dim HEAD_DIM, ACTUAL_HEAD_DIM, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k start_n, start_m, num_steps, # iteration numbers - None, None, None, None, + descale_q, descale_k, descale_v, descale_do, MASK=False, # causal masking ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) # end of GQA/MQA of dkdv - # Write back dV and dK. - adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn - offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk - tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + # Write back dV + adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn + offs_dv = offs_n[:, None] * stride_dvn + offs_d[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_kv) + # write back dk + adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn + offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd dk *= sm_scale - tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) # This part does dq start_m = pid * BLOCK_M2 @@ -637,13 +685,14 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead # Mask for loading K and V mask_q = offs_m[:, None] < seqlen_q if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_q &= mask_k[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk - offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn - K += adj_kv - V += adj_kv + mask_d = offs_d < ACTUAL_HEAD_DIM + mask_q &= mask_d[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + # NOTE: don't assume that the strides for k and v are the same! + K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn + V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn + # If MQA / GQA, set the K and V head offsets appropriately. for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front @@ -659,6 +708,12 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam Delta_ptr = Delta + adj_delta + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + # batch_philox_offset is the ACTUALLY dropout offset # dropout_offset is for debug purpose and will be removed later batch_philox_offset = 0 @@ -668,8 +723,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead bid * stride_dropoutb + \ hqid * stride_dropouth dropout_offset = \ - dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth - + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) m = tl.load(M + adj_delta + offs_m * stride_deltam, @@ -680,24 +734,35 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead # start can only be 0 at minimum start_n = max(end_n - BLOCK_M2, 0) num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N2) + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) dq = _bwd_dq_inner( dq, - q, K, V, do, m, Delta_ptr, sm_scale, # - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropoutm, stride_dropoutn, # + q, K, V, do, m, Delta_ptr, sm_scale, + stride_qm, stride_qd, stride_kn, stride_kd, stride_vn, stride_vd, + stride_dropoutm, stride_dropoutn, stride_deltam, - seqlen_q, seqlen_k, # - BLOCK_M2, MASK_BLOCK_N2, # - HEAD_DIM, ACTUAL_HEAD_DIM, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - start_m, start_n, end_n, num_steps, # - None, None, None, None, + seqlen_q, seqlen_k, + BLOCK_M2, MASK_BLOCK_N2, + HEAD_DIM, ACTUAL_HEAD_DIM, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, MASK=True, # ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) @@ -706,28 +771,30 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead start_n = max(end_n - num_steps * BLOCK_N2, 0) if DEBUG_TRITON: print(f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 dq = _bwd_dq_inner( - dq, # - q, K, V, do, m, Delta_ptr, sm_scale, # - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, # - stride_dropoutm, stride_dropoutn, # + dq, + q, K, V, do, m, Delta_ptr, sm_scale, + stride_qm, stride_qd, stride_kn, stride_kd, stride_vn, stride_vd, + stride_dropoutm, stride_dropoutn, stride_deltam, - seqlen_q, seqlen_k, # - BLOCK_M2, BLOCK_N2, # - HEAD_DIM, ACTUAL_HEAD_DIM, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - start_m, start_n, end_n, num_steps, # - None, None, None, None, - MASK=False, # + seqlen_q, seqlen_k, + BLOCK_M2, BLOCK_N2, + HEAD_DIM, ACTUAL_HEAD_DIM, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + MASK=False, ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) # Write back dQ. adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd dq *= sm_scale tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) # end of GQA/MQA of dq @@ -741,18 +808,23 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead def bwd_kernel_noncausal( Q, K, V, sm_scale, DO, DQ, DK, DV, M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_qb, stride_qh, stride_qm, stride_qd, + stride_kb, stride_kh, stride_kn, stride_kd, + stride_vb, stride_vh, stride_vn, stride_vd, + stride_dqb, stride_dqh, stride_dqm, stride_dqd, + stride_dkb, stride_dkh, stride_dkn, stride_dkd, + stride_dvb, stride_dvh, stride_dvn, stride_dvd, stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, + stride_dob, stride_doh, stride_dom, stride_dod, stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, HQ, HK, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset_base, + Dropout_mask, dropout_p, philox_seed, philox_offset_base, + Alibi_slopes, + Descale_q, Descale_k, Descale_v, Descale_do, BLOCK_M1: tl.constexpr, # 32 BLOCK_N1: tl.constexpr, # 128 BLOCK_M2: tl.constexpr, # 128 @@ -762,7 +834,11 @@ def bwd_kernel_noncausal( ACTUAL_HEAD_DIM: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, DEBUG_TRITON: tl.constexpr, DEBUG_TRITON_DETAIL: tl.constexpr, ): @@ -786,7 +862,7 @@ def bwd_kernel_noncausal( seqlen_k = k_end - k_start PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - offs_k = tl.arange(0, HEAD_DIM) + offs_d = tl.arange(0, HEAD_DIM) GROUP_SIZE: tl.constexpr = HQ // HK start_n = pid * BLOCK_N1 @@ -798,15 +874,15 @@ def bwd_kernel_noncausal( # Mask for loading K and V mask_kv = offs_n[:, None] < seqlen_k if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_kv &= mask_k[None, :] - offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk - + mask_d = offs_d < ACTUAL_HEAD_DIM + mask_kv &= mask_d[None, :] + # NOTE: don't assume that the strides for k and v are the same! # K/V tensors not changed for the group - adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0) - v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0) + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) # If MQA / GQA, set the K and V head offsets appropriately. for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): # offset input and output tensor by batch and Q/K heads @@ -818,6 +894,12 @@ def bwd_kernel_noncausal( M_ptr = M + adj_delta Delta_ptr = Delta + adj_delta + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + # batch_philox_offset is the ACTUALLY dropout offset # dropout_offset is for debug purpose and will be removed later batch_philox_offset = 0 @@ -825,40 +907,53 @@ def bwd_kernel_noncausal( if ENABLE_DROPOUT: batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ hqid * stride_dropouth - dropout_offset = dropout_mask + bid * stride_dropoutb + \ + dropout_offset = Dropout_mask + bid * stride_dropoutb + \ hqid * stride_dropouth + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + # because there is no causal, we always start from the beginning start_m = 0 num_steps = tl.cdiv(seqlen_q, BLOCK_M1) dk, dv = _bwd_dkdv_inner( dk, dv, # output tensors Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qk, # strides for q - stride_dom, stride_dok, # strides for o + stride_qm, stride_qd, # strides for q + stride_dom, stride_dod, # strides for o stride_dropoutm, stride_dropoutn, # strides for dropout stride_deltam, BLOCK_M1, BLOCK_N1, # block dim HEAD_DIM, ACTUAL_HEAD_DIM, # head dim dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k start_n, start_m, num_steps, # iteration numbers - None, None, None, None, + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user MASK=False, # causal masking ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) - # Write back dV and dK. - adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn - offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk - tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + # Write back dV + adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn + offs_dv = offs_n[:, None] * stride_dvn + offs_d[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_kv) + # write back dk + adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn + offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd dk *= sm_scale - tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) # THIS PART DOES DQ start_m = pid * BLOCK_M2 @@ -867,13 +962,12 @@ def bwd_kernel_noncausal( # Mask for loading K and V mask_q = offs_m[:, None] < seqlen_q if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_q &= mask_k[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk - offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn - K += adj_kv - V += adj_kv + mask_d = offs_d < ACTUAL_HEAD_DIM + mask_q &= mask_d[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn + V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn # If MQA / GQA, set the K and V head offsets appropriately. for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): # offset input and output tensor by batch and Q/K heads @@ -883,6 +977,12 @@ def bwd_kernel_noncausal( bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam Delta_ptr = Delta + adj_delta + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + # batch_philox_offset is the ACTUALLY dropout offset # dropout_offset is for debug purpose and will be removed later batch_philox_offset = 0 @@ -892,7 +992,7 @@ def bwd_kernel_noncausal( bid * stride_dropoutb + \ hqid * stride_dropouth dropout_offset = \ - dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) @@ -900,6 +1000,14 @@ def bwd_kernel_noncausal( mask=offs_m < seqlen_q) m = m[:, None] + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + # start can only be 0 at minimum start_n = 0 end_n = seqlen_k @@ -907,31 +1015,39 @@ def bwd_kernel_noncausal( dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) dq = _bwd_dq_inner( - dq, # - q, K, V, do, m, Delta_ptr, sm_scale, # - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, # - stride_dropoutm, stride_dropoutn, # + dq, + q, K, V, do, m, Delta_ptr, sm_scale, + stride_qm, stride_qd, stride_kn, stride_kd, stride_vn, stride_vd, + stride_dropoutm, stride_dropoutn, stride_deltam, - seqlen_q, seqlen_k, # - BLOCK_M2, BLOCK_N2, # - HEAD_DIM, ACTUAL_HEAD_DIM, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - start_m, start_n, end_n, num_steps, # - None, None, None, None, - MASK=False, # + seqlen_q, seqlen_k, + BLOCK_M2, BLOCK_N2, + HEAD_DIM, ACTUAL_HEAD_DIM, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + MASK=False, ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, - IS_FP8=False, - FP8_MAX=None, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) # Write back dQ. adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd dq *= sm_scale tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) +def is_contiguous(x, name): + if x.is_contiguous(): + return x + else: + print(f"{name} is not contiguous") + return x.contiguous() def attention_prefill_backward_triton_split_oneKernel_impl( do: torch.Tensor, @@ -955,11 +1071,55 @@ def attention_prefill_backward_triton_split_oneKernel_impl( philox_seed: Optional[int], philox_offset: Optional[int], use_exp2: bool, + # fp8 + descale_q: Optional[torch.Tensor], + descale_k: Optional[torch.Tensor], + descale_v: Optional[torch.Tensor], + descale_o: Optional[torch.Tensor], + descale_do: Optional[torch.Tensor], + descale_dq: Optional[torch.Tensor], + descale_dk: Optional[torch.Tensor], + descale_dv: Optional[torch.Tensor], ): # debug DEBUG_TRITON: bool = False DEBUG_TRITON_DETAIL: bool = False + # do = is_contiguous(do, "do") + # q = is_contiguous(q, "q") + # k = is_contiguous(k, "k") + # v = is_contiguous(v, "v") + # o = is_contiguous(o, "o") + # softmax_lse = is_contiguous(softmax_lse, "softmax_lse") + # dq = is_contiguous(dq, "dq") + # dk = is_contiguous(dk, "dk") + # dv = is_contiguous(dv, "dv") + + IS_FP8 = is_fp8(q) + if IS_FP8: + FP8_MAX = torch.finfo(q.dtype).max + # assert that the main inputs are fp8 + assert is_fp8(do) and is_fp8(q) and is_fp8(k) and is_fp8(v), f"Non fp8 type found: do.dtype={do.dtype}, q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}. All tensors must be fp8." + if is_fp8(o): + FP8_OUTPUT = True + assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o." + assert descale_dq is not None, f"descale_dq is None. In fp8, you need to pass a tensor for descale_dq along with a tensor dq." + assert descale_dk is not None, f"descale_dk is None. In fp8, you need to pass a tensor for descale_dk along with a tensor dk." + assert descale_dv is not None, f"descale_dv is None. In fp8, you need to pass a tensor for descale_dv along with a tensor dv." + else: + FP8_OUTPUT = False + + stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None + stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None + stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None + stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None + stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None + else: + FP8_MAX = None + FP8_OUTPUT = False + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = stride_descale_do_z = None + + # get strides and shape batch, nheads_q, nheads_k, head_size, max_seqlen_q_final, max_seqlen_k_final = \ get_shapes_from_layout( @@ -969,21 +1129,23 @@ def attention_prefill_backward_triton_split_oneKernel_impl( ) q_strides, k_strides, v_strides, o_strides = \ get_strides_from_layout(q, k, v, o, layout) - stride_qb, stride_qh, stride_qm, stride_qk = q_strides - stride_kb, stride_kh, stride_kn, stride_kk = k_strides - stride_vb, stride_vh, stride_vn, stride_vk = v_strides - stride_ob, stride_oh, stride_om, stride_ok = o_strides - dq_strides, dk_strides, _, do_strides = \ + stride_qb, stride_qh, stride_qm, stride_qd = q_strides + stride_kb, stride_kh, stride_kn, stride_kd = k_strides + stride_vb, stride_vh, stride_vn, stride_vd = v_strides + stride_ob, stride_oh, stride_om, stride_od = o_strides + dq_strides, dk_strides, dv_strides, do_strides = \ get_strides_from_layout(dq, dk, dv, do, layout) - stride_dqb, stride_dqh, stride_dqm, stride_dqk = dq_strides - stride_dkb, stride_dkh, stride_dkn, stride_dkk = dk_strides - stride_dob, stride_doh, stride_dom, stride_dok = do_strides + stride_dqb, stride_dqh, stride_dqm, stride_dqd = dq_strides + stride_dkb, stride_dkh, stride_dkn, stride_dkd = dk_strides + stride_dvb, stride_dvh, stride_dvn, stride_dvd = dv_strides + stride_dob, stride_doh, stride_dom, stride_dod = do_strides IS_VARLEN = layout == "thd" use_dropout = (dropout_p > 0.0) + use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) # get closest power of 2 over or equal to 32. padded_d_model = 1 << (head_size - 1).bit_length() - padded_d_model = max(padded_d_model, 16) + padded_d_model = max(padded_d_model, 32) HEAD_DIM = padded_d_model ACTUAL_HEAD_DIM = head_size @@ -998,17 +1160,20 @@ def attention_prefill_backward_triton_split_oneKernel_impl( _bwd_preprocess[pre_grid]( o, do, delta, - stride_ob, stride_oh, stride_om, stride_ok, + stride_ob, stride_oh, stride_om, stride_od, stride_deltab, stride_deltah, stride_deltam, - 0, + stride_descale_do_z, cu_seqlens_q, max_seqlen_q_final, - None, + descale_do, HEAD_DIM=HEAD_DIM, ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, IS_VARLEN=IS_VARLEN, - IS_FP8=False + IS_FP8=IS_FP8 ) + if DEBUG: + print("delta:", delta, delta.shape) + # dropout mask tensor for debugging. We dump the dropout mask created in # the kernel for testing dropout_mask = None @@ -1043,23 +1208,32 @@ def attention_prefill_backward_triton_split_oneKernel_impl( bwd_kernel_causal[grid]( q, k, v, sm_scale, do, dq, dk, dv, softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_qb, stride_qh, stride_qm, stride_qd, + stride_kb, stride_kh, stride_kn, stride_kd, + stride_vb, stride_vh, stride_vn, stride_vd, + stride_dqb, stride_dqh, stride_dqm, stride_dqd, + stride_dkb, stride_dkh, stride_dkn, stride_dkd, + stride_dvb, stride_dvh, stride_dvn, stride_dvd, stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, + stride_dob, stride_doh, stride_dom, stride_dod, stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, nheads_q, nheads_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q_final, max_seqlen_k_final, dropout_mask, dropout_p, philox_seed, philox_offset, - HEAD_DIM=HEAD_DIM, + alibi_slopes, + descale_q, descale_k, descale_v, descale_do, + HEAD_DIM=HEAD_DIM, ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=FP8_OUTPUT, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) @@ -1067,23 +1241,32 @@ def attention_prefill_backward_triton_split_oneKernel_impl( bwd_kernel_noncausal[grid]( q, k, v, sm_scale, do, dq, dk, dv, softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_qb, stride_qh, stride_qm, stride_qd, + stride_kb, stride_kh, stride_kn, stride_kd, + stride_vb, stride_vh, stride_vn, stride_vd, + stride_dqb, stride_dqh, stride_dqm, stride_dqd, + stride_dkb, stride_dkh, stride_dkn, stride_dkd, + stride_dvb, stride_dvh, stride_dvn, stride_dvd, stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, + stride_dob, stride_doh, stride_dom, stride_dod, stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, nheads_q, nheads_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q_final, max_seqlen_k_final, dropout_mask, dropout_p, philox_seed, philox_offset, + alibi_slopes, + descale_q, descale_k, descale_v, descale_do, HEAD_DIM=HEAD_DIM, ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=FP8_OUTPUT, DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) diff --git a/flash_attn/flash_attn_triton_amd/bwd_ref.py b/flash_attn/flash_attn_triton_amd/bwd_ref.py index 90a98ce4fcc..639211a51f6 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/bwd_ref.py @@ -122,7 +122,7 @@ def attention_backward_core_ref_impl( print("dp:", dp, dp.shape) # calculate ds - if False: + if True: delta = torch.sum(o * do, axis=-1).unsqueeze(-1) else: delta = torch.sum(p * dp, axis=-1).unsqueeze(-1) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 6f69cd02813..4d9a7bf17b1 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -2,7 +2,7 @@ import triton import triton.language as tl from typing import Literal, Optional, Union -from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_shapes_from_layout, get_strides_from_layout, is_cdna, is_fp8, is_rdna, create_dropout_mask +from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_arch, get_shapes_from_layout, get_strides_from_layout, is_cdna, is_fp8, is_rdna, create_dropout_mask # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) @@ -80,22 +80,24 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE + if USE_ALIBI: + # compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, + global_n_positions) + qk_scaled += alibi_block + if IS_CAUSAL: causal_boundary = start_n + offs_n_causal causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) + if bias_ptrs is not None: bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) qk_scaled += bias - if USE_ALIBI: - # compute the global position of each token within the sequence - global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - global_n_positions = start_n + tl.arange(0, BLOCK_N) - alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, - global_n_positions) - qk_scaled += alibi_block # get max scores so far m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) @@ -213,12 +215,28 @@ def get_autotune_configs(): else: raise ValueError("Unknown Device Type") else: - return [ - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + arch = get_arch() + if arch == "gfx950": + default_config = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, num_warps=4, - ), + ) + elif arch == "gfx942" and False: # Disabled due shared mem oom in CI when using triton==3.3.0 when using top of tree everything seems fine. + default_config = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + else: + default_config = triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + + return [ + default_config ], [ "IS_CAUSAL", "dropout_p", @@ -250,14 +268,29 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, IS_INFERENCE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr): + IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr, FLIP_GRID: tl.constexpr): # set params ACCUMULATOR_TYPE = tl.float32 # compute offsets - start_m = tl.program_id(0) - off_h_q = tl.program_id(1) - off_z = tl.program_id(2) + if FLIP_GRID: + #NUM_XCDS: tl.constexpr = 8 + off_z = tl.program_id(0) + off_h_q = tl.program_id(1) + start_m = tl.program_id(2) + + #start_m = (tl.cdiv(MAX_SEQLENS_Q, BLOCK_M) - 1) - start_m + + # Remap heads to the same XCD + #pids_per_xcd = HQ // NUM_XCDS + #xcd_group = off_h_q % NUM_XCDS + #pid_in_xcd = off_h_q // NUM_XCDS + #off_h_q = xcd_group * pids_per_xcd + pid_in_xcd + else: + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) @@ -308,7 +341,7 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) - o_ptrs_mask = offs_m[:, None] < seqlen_q + o_ptrs_mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) # We still need to write 0s to the result tl.store(o_ptrs, acc, mask=o_ptrs_mask) # The tensor allocated for L is based on MAX_SEQLENS_Q as that is @@ -598,7 +631,11 @@ def attention_prefill_forward_triton_impl( # kernel is padded - there is no padding in memory for any dims. padded_d_model = max(padded_d_model, 16) - grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch) + FLIP_GRID = True + if FLIP_GRID: + grid = lambda META: (batch, nheads_q, triton.cdiv(max_seqlens_q, META['BLOCK_M'])) + else: + grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch) # sd_mask is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according @@ -644,6 +681,6 @@ def attention_prefill_forward_triton_impl( MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, IS_VARLEN=is_varlen, IS_INFERENCE=is_inference, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p - > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT) + > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT, FLIP_GRID=FLIP_GRID) return softmax_lse, sd_mask if return_softmax else None diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index 06ab7d24d56..cbe597c5dbe 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -13,6 +13,10 @@ from flash_attn.layers.rotary import apply_rotary_emb from typing import Literal, Optional, Union + +USE_EXP2 = True +BWD_MODE = os.environ.get('BWD_MODE', 'jingning').lower() + def fwd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -66,12 +70,19 @@ def fwd(q: torch.Tensor, if return_softmax: metadata.return_scores = True - batch, nheads_q, nheads_k, head_size, _, _ = get_shapes_from_layout(q, k, metadata.layout) + # get shape + batch, _ , nheads_q, _= q.shape if causal: metadata.need_causal(True) if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError(f"Alibi can be (nheads,) or (batch_size, nheads). Given tensor with shape {alibi_slopes.shape}") metadata.need_alibi(alibi_slopes, batch, nheads_q) # store rng state @@ -101,7 +112,7 @@ def fwd(q: torch.Tensor, metadata.dropout_p, metadata.philox_seed, metadata.philox_offset, - metadata.use_exp2) + USE_EXP2) softmax_lse=softmax_lse_ref sd_mask=sd_mask_ref else: @@ -127,7 +138,7 @@ def fwd(q: torch.Tensor, metadata.philox_seed, metadata.philox_offset, metadata.return_scores, - metadata.use_exp2, + USE_EXP2, descale_q, descale_k, descale_v, @@ -145,7 +156,6 @@ def fwd(q: torch.Tensor, return out, softmax_lse, sd_mask, rng_state -BWD_MODE = os.environ.get('BWD_MODE', 'split').lower() def bwd( dout: torch.Tensor, q: torch.Tensor, @@ -210,11 +220,23 @@ def bwd( dk = torch.zeros_like(k) if dk is None else dk.zero_() dv = torch.zeros_like(v) if dv is None else dv.zero_() - if rng_state is not None: + # get shape + batch, _ , nheads_q, _= q.shape + + if dropout_p > 0.0: + assert rng_state is not None philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() else: philox_seed, philox_offset = None, None + if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") + # call implementation if USE_REF: if DEBUG: @@ -241,7 +263,7 @@ def bwd( dropout_p, philox_seed, philox_offset, - False, + USE_EXP2, ) delta = delta_ref else: @@ -269,7 +291,7 @@ def bwd( dropout_p, philox_seed, philox_offset, - False, + USE_EXP2, descale_q, descale_k, descale_v, @@ -330,7 +352,15 @@ def bwd( dropout_p, philox_seed, philox_offset, - False + USE_EXP2, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + descale_dq, + descale_dk, + descale_dv, ) delta = delta_triton else: @@ -411,13 +441,20 @@ def varlen_fwd( metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # set layout to "thd" and other metdata assert metadata.layout is not None - # get shapes - batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shapes_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) + # get shape + batch = len(cu_seqlens_q) - 1 + _, nheads_q, _= q.shape if causal: metadata.need_causal(True) if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") metadata.need_alibi(alibi_slopes, batch, nheads_q) # store rng state @@ -447,7 +484,7 @@ def varlen_fwd( metadata.dropout_p, metadata.philox_seed, metadata.philox_offset, - metadata.use_exp2) + USE_EXP2) softmax_lse=softmax_lse_ref sd_mask=sd_mask_ref else: @@ -473,7 +510,7 @@ def varlen_fwd( metadata.philox_seed, metadata.philox_offset, metadata.return_scores, - metadata.use_exp2, + USE_EXP2, descale_q, descale_k, descale_v, @@ -558,11 +595,24 @@ def varlen_bwd( dk = torch.zeros_like(k) if dk is None else dk.zero_() dv = torch.zeros_like(v) if dv is None else dv.zero_() - if rng_state is not None: + # get shape + batch = len(cu_seqlens_q) - 1 + _, nheads_q, _= q.shape + + if dropout_p > 0.0: + assert rng_state is not None philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() else: philox_seed, philox_offset = None, None + if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") + # call implementation if USE_REF: if DEBUG: @@ -588,44 +638,108 @@ def varlen_bwd( dropout_p, philox_seed, philox_offset, - False, + USE_EXP2, ) delta = delta_ref else: if DEBUG: print("Using Triton implementation") - delta_triton = attention_prefill_backward_triton_split_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - "thd", - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - False, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - delta = delta_triton + if BWD_MODE == "split": + delta_triton = attention_prefill_backward_triton_split_impl( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + "thd", + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + USE_EXP2, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + descale_dq, + descale_dk, + descale_dv, + ) + delta = delta_triton + elif BWD_MODE == "fused": + delta_triton = attention_prefill_backward_triton_fused_impl( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + descale_o, + True, + ) + delta = delta_triton + elif BWD_MODE == "jingning": + delta_triton = attention_prefill_backward_triton_split_oneKernel_impl( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + "thd", + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + USE_EXP2, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + descale_dq, + descale_dk, + descale_dv, + ) + delta = delta_triton + else: + raise ValueError(f"Unknown bwd mode {BWD_MODE}") if DEBUG: print("varlen_bwd outputs") @@ -697,11 +811,19 @@ def fwd_kvcache( k_new = k v_new = v + # get shape + batch, _ , nheads_q, _= q.shape + if causal: metadata.need_causal(True) if alibi_slopes is not None: - batch, _ , nheads_q, _= q.shape + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") metadata.need_alibi(alibi_slopes, batch, nheads_q) # rotary boolean @@ -779,7 +901,7 @@ def fwd_kvcache( metadata.philox_seed, metadata.philox_offset, metadata.return_scores, - metadata.use_exp2, + USE_EXP2, None, None, None, diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index 58e2ae5fc7f..ea82de065b5 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -23,7 +23,7 @@ from .utils import DEBUG, input_helper, arch_supports_fp8 from .fwd_ref import attention_forward_pytorch_ref_impl from .fwd_prefill import attention_prefill_forward_triton_impl -from .bwd_prefill_split import attention_prefill_backward_triton_split_impl +from .bwd_prefill_onekernel import attention_prefill_backward_triton_split_oneKernel_impl from .bwd_ref import attention_backward_pytorch_ref_impl # set print options @@ -83,6 +83,7 @@ @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('use_exp2', [True, False]) # works when use_exp2 is false @pytest.mark.parametrize('DEBUG_INPUT', [False]) # NOTE: debug input can overflow when the tensors are large. Just use to figure out issues +@pytest.mark.skip() def test_op_prefill_fwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, alibi_slopes, layout, dtype, use_exp2, DEBUG_INPUT): torch.manual_seed(42) device = "cuda" @@ -96,7 +97,6 @@ def test_op_prefill_fwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dr print("MHA") # update metadata - metadata.use_exp2 = use_exp2 if causal: metadata.need_causal(True) @@ -129,7 +129,7 @@ def test_op_prefill_fwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dr metadata.philox_seed, metadata.philox_offset, metadata.return_scores, - metadata.use_exp2, + use_exp2, None, None, None, @@ -259,6 +259,7 @@ def test_op_prefill_fwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dr @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('use_exp2', [False]) # FIXME: using exp2 causes issue when used with causal @pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans on larger tensors +@pytest.mark.skip() def test_op_prefill_bwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, alibi_slopes, layout, dtype, use_exp2, DEBUG_INPUT): torch.manual_seed(20) device="cuda" @@ -333,7 +334,7 @@ def test_op_prefill_bwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dr dq_triton = torch.zeros_like(q_triton, dtype=q.dtype) # NOTE: the kernel does inplace accumlation on dq so dq has to be zeros dk_triton = torch.zeros_like(k_triton, dtype=k.dtype) if DEBUG_INPUT else torch.empty_like(k_triton, dtype=k.dtype) dv_triton = torch.zeros_like(v_triton, dtype=v.dtype) if DEBUG_INPUT else torch.empty_like(v_triton, dtype=v.dtype) - delta_triton = attention_prefill_backward_triton_split_impl( + delta_triton = attention_prefill_backward_triton_split_oneKernel_impl( do_triton, q_triton, k_triton, diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 5d3bf02e1f8..b129636e31d 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -12,6 +12,8 @@ # Gloabl Variables # ------------------------------- AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes') +if AUTOTUNE: + os.environ["TRITON_PRINT_AUTOTUNING"] = "1" DEBUG = os.environ.get('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in ('1', 'true', 'yes') USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes') PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes') @@ -48,7 +50,6 @@ class MetaData(): philox_seed: Optional[int] = None philox_offset : Optional[int]= None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing. # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW. - use_exp2: bool = False rotary_sin: Optional[torch.Tensor] = None rotary_cos: Optional[torch.Tensor] = None rotary_interleaved: bool = False diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py index b5e026803c2..6073cb1c35a 100755 --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -16,7 +16,7 @@ from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import _get_block_size_n from flash_attn.layers.rotary import apply_rotary_emb -from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_rdna +from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_hip, is_rdna MAX_HEADDIM_SM8x = 192 @@ -26,6 +26,8 @@ is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0) is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0) +skip_bfloat16 = True if is_sm75 or is_hip() else False + def attn_bias_from_alibi_slopes( slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None @@ -565,7 +567,7 @@ def get_dropout_fraction( return dropped.sum() / valid.sum() -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("deterministic", [False]) @@ -714,7 +716,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("deterministic", [True]) @@ -862,9 +864,9 @@ def test_flash_attn_varlen_qkvpacked( assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() -@pytest.mark.parametrize("kvpacked", [False]) +@pytest.mark.parametrize("kvpacked", [True, False]) # @pytest.mark.parametrize("kvpacked", [False]) -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @@ -1139,7 +1141,7 @@ def test_flash_attn_output( @pytest.mark.parametrize("kvpacked", [False]) # @pytest.mark.parametrize('kvpacked', [False]) -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize('mha_type', ["mqa"]) @@ -1459,7 +1461,7 @@ def test_flash_attn_varlen_output( assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True]) @@ -1572,7 +1574,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True]) @@ -1741,7 +1743,7 @@ def test_flash_attn_varlen_causal( assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("deterministic", [True]) @@ -1871,7 +1873,7 @@ def test_flash_attn_splitkv( assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4 -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("num_splits", [1, 0]) # @pytest.mark.parametrize("num_splits", [1]) @@ -2183,7 +2185,7 @@ def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [True]) @@ -2310,7 +2312,7 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): ).abs().max().item() + 1e-3 -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.bfloat16]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [False]) @@ -2400,7 +2402,7 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): assert not v.grad.isnan().any() -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True]) @@ -2459,7 +2461,7 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc assert torch.equal(dq, dq0) -@pytest.mark.parametrize("dtype", ([torch.float16])) +@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True]) From a9b2d5f329502d1b796b22fe460c2f0fffa26719 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 16 May 2025 06:10:28 +0300 Subject: [PATCH 02/27] head, seq, batch (#141) --- .../bwd_prefill_onekernel.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py index 9f8a1ab46a2..089676f5b0b 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py @@ -449,7 +449,7 @@ def _bwd_dq_inner( use_cuda_graph=True, ) @triton.jit -def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) +def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), batch) Q, K, V, sm_scale, DO, DQ, DK, DV, M, Delta, stride_qb, stride_qh, stride_qm, stride_qd, @@ -487,9 +487,9 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhead DEBUG_TRITON_DETAIL: tl.constexpr, ): # program ids - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) + hkid = tl.program_id(0) + pid = tl.program_id(1) + bid = tl.program_id(2) if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 # figure out varlen start and end q_start = 0 @@ -843,9 +843,9 @@ def bwd_kernel_noncausal( DEBUG_TRITON_DETAIL: tl.constexpr, ): # program ids - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) + hkid = tl.program_id(0) + pid = tl.program_id(1) + bid = tl.program_id(2) if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 # figure out varlen start and end q_start = 0 @@ -1202,7 +1202,7 @@ def attention_prefill_backward_triton_split_oneKernel_impl( dropout_mask.stride() seqlen = max(max_seqlen_q_final, max_seqlen_k_final) - grid = lambda META: ((seqlen + META['BLOCK_N1'] - 1) // META['BLOCK_N1'], batch, nheads_k) + grid = lambda META: (nheads_k, (seqlen + META['BLOCK_N1'] - 1) // META['BLOCK_N1'], batch, ) if causal: if DEBUG_TRITON: print(f"bwd_kernel: grid = {grid}" ) # noqa: E701 bwd_kernel_causal[grid]( From 87e6594cd93732c94741926a5283dfb2ce20fd9b Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Tue, 3 Jun 2025 15:20:36 -0400 Subject: [PATCH 03/27] Fix keys (#144) * save * rm keys * fix keys * use GHA_RENDER_DEVICES * normal docker --- .gitignore | 23 +++++++++++++++++-- .../bwd_prefill_onekernel.py | 8 +++---- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index dc508654045..9559348b043 100644 --- a/.gitignore +++ b/.gitignore @@ -32,5 +32,24 @@ var/ # Dev venv -# compile-time generated file -flash_attn_config.py \ No newline at end of file +# AMD +scripts +csrc/flash_attn_ck +.eggs +log +*.log +core.* +gpucore.* +*.csv +*.png +*.html +*.json +*.txt +*.pth +*.md +*.crt +training/logs +training/data +# ck modules +csrc/composable_kernel +csrc/cutlass diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py index 089676f5b0b..67f7498f083 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py @@ -69,21 +69,21 @@ def get_autotune_configs(): triton.Config({"PRE_BLOCK": PRE_BLOCK, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), ] preprocess_autotune_keys = [ - "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "max_seqlen_q", + "ACTUAL_HEAD_DIM", "IS_VARLEN", ] causal_autotune_configs = [ triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), ] causal_autotune_keys = [ - "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", + "dropout_p", "max_seqlen_q", "max_seqlen_k", "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", ] noncausal_autotune_configs = [ triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), ] noncausal_autotune_keys = [ - "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", + "dropout_p", "max_seqlen_q", "max_seqlen_k", "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", ] return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) From 2da51ae7c00c71770fe33a914f78060e5acdd7b0 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 26 Jun 2025 20:10:07 -0400 Subject: [PATCH 04/27] Pad LSE (#148) * add round multiple * fix fwd * backward fix * use rounded lse flag * passing ROUNDED_LSE * default is new rounded mode * rename to fused_atmoics and fused_no_atomics * add test for torch_compile * add varlen torch compile test * add old one kernel for ref * fix varlen mismatch bug * fix shape issue in varlen but mismatch * sync torch compile kernel launch * simple varlen test * add debug code * rm old * ignore old impls * DEBUG flag works in interface only * ref uses the righ shape for lse * rm oldest bwd kernel * fix typo * fix varlen bug * fix bug. Get info from q for now * simple shape and stride checkout * add more tests * test kvcache * kvcache safe * match case * fix segfault due to bad return_softmax * run bench * run seperate for the main functions * just output benchmark * default csv format and time stamp files * non verbsoe bench --- .github/workflows/amd_nightly.yml | 105 -- .github/workflows/amd_tests.yml | 4 +- flash_attn/flash_attn_triton_amd/.gitignore | 2 + flash_attn/flash_attn_triton_amd/bench.py | 311 ++-- .../flash_attn_triton_amd/bwd_prefill.py | 814 --------- ..._fused.py => bwd_prefill_fused_atomics.py} | 1487 +---------------- ...nel.py => bwd_prefill_fused_no_atomics.py} | 168 +- .../bwd_prefill_split.py | 4 +- flash_attn/flash_attn_triton_amd/bwd_ref.py | 19 +- .../flash_attn_triton_amd/fwd_decode.py | 4 +- .../flash_attn_triton_amd/fwd_prefill.py | 71 +- flash_attn/flash_attn_triton_amd/fwd_ref.py | 8 +- .../flash_attn_triton_amd/interface_fa.py | 36 +- flash_attn/flash_attn_triton_amd/test.py | 205 ++- flash_attn/flash_attn_triton_amd/utils.py | 388 ++++- 15 files changed, 930 insertions(+), 2696 deletions(-) delete mode 100644 .github/workflows/amd_nightly.yml create mode 100644 flash_attn/flash_attn_triton_amd/.gitignore delete mode 100644 flash_attn/flash_attn_triton_amd/bwd_prefill.py rename flash_attn/flash_attn_triton_amd/{bwd_prefill_fused.py => bwd_prefill_fused_atomics.py} (55%) rename flash_attn/flash_attn_triton_amd/{bwd_prefill_onekernel.py => bwd_prefill_fused_no_atomics.py} (89%) diff --git a/.github/workflows/amd_nightly.yml b/.github/workflows/amd_nightly.yml deleted file mode 100644 index 3131496ac49..00000000000 --- a/.github/workflows/amd_nightly.yml +++ /dev/null @@ -1,105 +0,0 @@ -name: AMD Nightly Kernel Tests - -on: - workflow_dispatch: - push: - branches: [main_perf] - schedule: - - cron: '0 0 * * *' # runs nightly at midnight UTC - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - Nightly-CDNA-AMD: - runs-on: ${{ matrix.runner }} - strategy: - matrix: - runner: [linux-mi300-gpu-1] - fail-fast: false # disables failing the entire job when one matrix entry fails - timeout-minutes: 720 # self hosted runners can run jobs for longer than the default of 360 minutes - container: - image: rocm/pytorch:latest - options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --shm-size 16G --group-add video --user root - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Show Device Info - run: | - rocminfo | grep gfx - - - name: Uninstall Triton - run: | - pip uninstall -y triton - rm -rf ~/.triton - rm -rf ./triton/python/build - - - name: Install Triton - run: | - pip install triton==3.3.0 - - - name: Show Triton version - run: | - pip show triton - - - name: Build - run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install - - - name: Install dependencies for bench and misc - run: | - pip install matplotlib pandas tabulate - - - name: AMD Internal Tests - run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest flash_attn/flash_attn_triton_amd/test.py - - - name: Flash Attention Tests - run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest -n 8 tests/test_flash_attn_triton_amd.py - - - name: AMD Bench - run: | - python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_func flash_attn_varlen_func flash_attn_with_kvcache - - Nightly-RDNA-AMD: - runs-on: ${{ matrix.runner }} - strategy: - matrix: - runner: [gfx1100] - fail-fast: false # disables failing the entire job when one matrix entry fails - timeout-minutes: 720 # self hosted runners can run jobs for longer than the default of 360 minutes - container: - image: rocm/pytorch:latest - options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Show Device Info - run: | - rocminfo | grep gfx - - - name: Uninstall Triton - run: | - pip uninstall -y triton - rm -rf ~/.triton - rm -rf ./triton/python/build - - - name: Install Triton - run: | - pip install triton==3.3.0 - - - name: Show Triton version - run: | - pip show triton - - - name: Build - run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install - - - name: Flash Attention Tests - run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest tests/test_flash_attn_triton_amd.py::test_flash_attn_output diff --git a/.github/workflows/amd_tests.yml b/.github/workflows/amd_tests.yml index 2f49567f960..2e3f061c78d 100644 --- a/.github/workflows/amd_tests.yml +++ b/.github/workflows/amd_tests.yml @@ -60,4 +60,6 @@ jobs: - name: AMD Bench run: | - python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_func flash_attn_varlen_func flash_attn_with_kvcache + python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_func + python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_varlen_func + python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_with_kvcache diff --git a/flash_attn/flash_attn_triton_amd/.gitignore b/flash_attn/flash_attn_triton_amd/.gitignore new file mode 100644 index 00000000000..21538fc4e4a --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/.gitignore @@ -0,0 +1,2 @@ +bwd_prefill_fused.py +bwd_prefill_onekernel.py \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/bench.py b/flash_attn/flash_attn_triton_amd/bench.py index f2b2e7d11d6..e19de575c8c 100755 --- a/flash_attn/flash_attn_triton_amd/bench.py +++ b/flash_attn/flash_attn_triton_amd/bench.py @@ -5,6 +5,9 @@ import time import argparse import itertools +import logging +import warnings +import datetime import pandas as pd from logging import warning from typing import Dict, List, Literal, Optional, Tuple @@ -73,6 +76,10 @@ "flash_attn_with_kvcache": ["fwd"], } + +# Add a global variable for verbose mode +VERBOSE = False + @dataclass class EnvVariableConfig: key: str @@ -80,7 +87,7 @@ class EnvVariableConfig: backend: Optional[Literal["triton", "ck"]] = None ENV_VARIABLE_CONFIGS : List[EnvVariableConfig] = [ - # EnvVariableConfig(key="BWD_MODE", values=["split", "fused", "jingning"], backend="triton"), + # EnvVariableConfig(key="BWD_MODE", values=["split", "fused_atomics", "fused_no_atomics"], backend="triton"), ] class FunctionConfig: @@ -108,53 +115,6 @@ def __str__(self): def column_name(self): return f"{self}_ms" - - -@lru_cache() -def available_backends(): - available = [] - - # try to load each backend - for backend in ["triton", "ck"]: - try: - # try loading the module with this backend - flash_attn = load_flash_attn_module(backend) - - # if we got here, the backend loaded successfully - available.append(backend) - except Exception as e: - # backend not available, just continue - print(f"Backend {backend} not available. Error: {e}") - - # if no backends available, default to triton - if not available: - raise ValueError("No Backends available") - - return available - -@lru_cache() -def get_fn_params(fn_name): - # get params for fn - packing = get_packing_type(fn_name) - is_varlen = True if "varlen" in fn_name else False - is_fp8 = True if "fp8" in fn_name else False - supported_dtypes = SUPPORTED_DTYPES.get(fn_name, [torch.float16]) # default to float16 if not found - supported_backends = [backend for backend in SUPPORTED_BACKENDS.get(fn_name, ["triton"]) if backend in available_backends()] # default to triton backend - supports_backward = False if fn_name in ["flash_attn_with_kvcache"] else True - supported_modes = SUPPORTED_MODES.get(fn_name, ["fwd"]) - device = "cuda" - - # get supported env configs for each backend - supported_env_configs = {} - for backend in supported_backends: - supported_env_configs[backend] = get_env_value_combinations(backend) - - # check backward pass support - if not supports_backward: - warning(f"{fn_name} does not have a backward pass so benching forward pass only.") - - return is_varlen, is_fp8, packing, supported_dtypes, supported_backends, supported_modes, supported_env_configs, device - def generate_fn_inputs( fn_name: str, BATCH: int, @@ -858,11 +818,12 @@ def get_packing_type(fn_name: str) -> Optional[Literal["kv", "qkv"]]: return packing -def load_flash_attn_module(backend: Literal["triton", "ck"], env_configs: Dict = {}, verbose = False): +def load_flash_attn_module(backend: Literal["triton", "ck"], env_configs: Dict = {}): """ Load the flash_attn module with the specified backend configuration """ - + global VERBOSE + # remove any existing env variables first for key in ENV_FLAGS: if key in os.environ: @@ -881,7 +842,7 @@ def load_flash_attn_module(backend: Literal["triton", "ck"], env_configs: Dict = # add custom env configs add_env_configs(env_configs) - if verbose: + if VERBOSE: # Only print if both local and global verbose are True print(f"Loading flash_attn module with {backend} backend.") # Remove any existing flash_attn modules from sys.modules @@ -894,6 +855,10 @@ def load_flash_attn_module(backend: Literal["triton", "ck"], env_configs: Dict = # Import and return the module import flash_attn + + # disable triton printing from autotuning + if not VERBOSE: + os.environ["TRITON_PRINT_AUTOTUNING"] = "0" return flash_attn @@ -907,11 +872,8 @@ def run_benchmark(func_config: FunctionConfig, input_configs): """ Runs the benchmark for the provided function configuration with the given input configurations. """ - # print new line to seperate benchmark runs - print() - if DEBUG: - print("func_config:", func_config) - + global VERBOSE + # extract function configuration parameters fn_name = func_config.fn_name mode = func_config.mode @@ -919,13 +881,14 @@ def run_benchmark(func_config: FunctionConfig, input_configs): backend = func_config.backend # load flash attention module - flash_attn_module = load_flash_attn_module(backend, func_config.env_configs, verbose=True) + flash_attn_module = load_flash_attn_module(backend, func_config.env_configs) # start timing the benchmark start_time = time.time() - - # print bench fn - print(f"Benchmarking {func_config} ...") + if VERBOSE: + print(f"Benchmarking {func_config} ...") + else: + print(f"Running {fn_name} ({mode}, {backend})...", end='', flush=True) # Setup benchmark configurations bench_configs = [ @@ -966,16 +929,15 @@ def bench_function( ms = triton.testing.do_bench(benchmark_fn, warmup=25, rep=100) return ms - df = bench_function.run(save_path=".", print_data=True, return_df=True)[0] + df = bench_function.run(return_df=True)[0] # set the column name to reflect the function configuration df = df.rename(columns={"Time (ms)": func_config.column_name()}) # calculate and print elapsed time elapsed_time = time.time() - start_time - print(f"Total time for benchmarking {fn_name} in {mode} mode with {dtype}: {elapsed_time:.2f} seconds") - return df + return df, elapsed_time def filter_modes(requested_modes, fn_name, supported_modes_for_fn): modes_to_run = [] @@ -1025,26 +987,88 @@ def get_input_config_set(config_type): return input_configs -def filter_backends(requested_backends, supported_backends, fn_name): +def available_backends(): + """Check which backends are available by trying to load them.""" + available = [] + + for backend in ["triton", "ck"]: + try: + # try loading the module with this backend + load_flash_attn_module(backend) + available.append(backend) + except Exception as e: + # backend not available, just continue + if DEBUG: + print(f"Backend {backend} not available: {e}") + + if not available: + raise ValueError("No backends are available. Please check your flash_attn installation.") + + return available + +# 2. Simplify get_fn_params to remove the backend filtering logic here +@lru_cache() +def get_fn_params(fn_name): + # get params for fn + packing = get_packing_type(fn_name) + is_varlen = True if "varlen" in fn_name else False + is_fp8 = True if "fp8" in fn_name else False + supported_dtypes = SUPPORTED_DTYPES.get(fn_name, [torch.float16]) + supported_backends = SUPPORTED_BACKENDS.get(fn_name, ["triton"]) # just get what the function supports + supports_backward = False if fn_name in ["flash_attn_with_kvcache"] else True + supported_modes = SUPPORTED_MODES.get(fn_name, ["fwd"]) + device = "cuda" + + # get supported env configs for each backend + supported_env_configs = {} + for backend in supported_backends: + supported_env_configs[backend] = get_env_value_combinations(backend) + + # check backward pass support + if not supports_backward: + warning(f"{fn_name} does not have a backward pass so benching forward pass only.") + + return is_varlen, is_fp8, packing, supported_dtypes, supported_backends, supported_modes, supported_env_configs, device + +# 3. Create a new simpler function to validate and filter backends +def validate_backends(requested_backends, supported_backends, fn_name): + """Validate that requested backends are available and supported.""" + # get actually available backends + available = available_backends() + + # determine which backends to use if requested_backends: - selected = [] - for be in requested_backends: - if be in supported_backends: - selected.append(be) - else: - warning( - f"backend '{be}' requested but not supported by " - f"function '{fn_name}'. skipping this back-end." - ) - return selected + # user specified backends - validate them + valid_backends = [] + for backend in requested_backends: + if backend not in available: + warning(f"Backend '{backend}' is not available on this system. Skipping.") + continue + if backend not in supported_backends: + warning(f"Backend '{backend}' is not supported by function '{fn_name}'. Skipping.") + continue + valid_backends.append(backend) + + if not valid_backends: + raise ValueError(f"None of the requested backends {requested_backends} are available and supported for {fn_name}") + + return valid_backends else: - return supported_backends - + # no backends specified - use all available and supported + valid_backends = [b for b in supported_backends if b in available] + + if not valid_backends: + raise ValueError(f"No available backends found for {fn_name}. Function supports {supported_backends} but only {available} are available.") + + return valid_backends +# 4. Update process_args to use the new validate_backends function def process_args(): """ Parses command-line arguments and returns function configs and input configs. """ + global VERBOSE + # create parser parser = argparse.ArgumentParser( prog="Benchmark FlashAttention", @@ -1064,16 +1088,35 @@ def process_args(): type=str, nargs='*', choices=VALID_MODES, - default=None, - help=f"Benchmarking mode(s) to run. If omitted, runs all supported modes for each function.", + default=["fwd", "bwd"], + help=f"Benchmarking mode(s) to run. Default: fwd, bwd", ) parser.add_argument( "--backend", type=str, nargs='*', choices=["triton", "ck"], - default=None, - help="Back-end(s) to run (triton, ck). Omit to run every back-end that is both available and supported by the function.", + default=["triton"], + help="Backend(s) to run. Default: triton", + ) + parser.add_argument( + "--output", + type=str, + choices=["ms", "tflops"], + default="tflops", + help="Output metric type: ms (milliseconds) or tflops (TFLOPS). Default: tflops", + ) + parser.add_argument( + "--format", + type=str, + choices=["csv", "markdown"], + default="csv", + help="Output file format: csv or markdown. Default: csv", + ) + parser.add_argument( + "--verbose", "-v", + action="store_true", + help="Enable verbose output (show autotuning details)", ) # config parser.add_argument("-b", type=int, default=None, help="Batch size") @@ -1087,15 +1130,21 @@ def process_args(): # parse args args = parser.parse_args() + + # Set global verbose flag + VERBOSE = args.verbose # parse function args benchmark_fns = args.benchmark_fn requested_modes = args.mode requested_backends = args.backend + output_type: Literal["ms", "tflops"] = args.output + output_format: Literal["csv", "markdown"] = args.format - # fenerate function configurations and input configurations separately + # generate function configurations and input configurations separately all_function_configs = [] all_input_configs = {} # Maps function config -> input configs + for fn_name in benchmark_fns: is_varlen, is_fp8, packing, supported_dtypes, supported_backends, supported_modes_for_fn, supported_env_configs, device = get_fn_params(fn_name) @@ -1115,10 +1164,7 @@ def process_args(): dropout = args.dropout if args.dropout is not None else 0.0 input_configs = [(batch, hq, hk, sq, sk, d_head, causal, dropout)] else: - if True: - input_configs = get_input_config_set("llama") - else: - input_configs = generate_benchmark_configs(is_varlen, packing) + input_configs = get_input_config_set("llama") # filter by mode modes_to_run = filter_modes(requested_modes, fn_name, supported_modes_for_fn) @@ -1126,12 +1172,11 @@ def process_args(): warning(f"No valid modes to run for function '{fn_name}' based on request and function support. Skipping this function.") continue - # filter by backend - backends_to_run = filter_backends(requested_backends, - supported_backends, - fn_name) - if not backends_to_run: - warning(f"no valid back-ends left for '{fn_name}'. skipping.") + # validate and filter backends + try: + backends_to_run = validate_backends(requested_backends, supported_backends, fn_name) + except ValueError as e: + warning(str(e)) continue # create a function config for each backend and dtype combination @@ -1149,7 +1194,7 @@ def process_args(): all_input_configs[func_config] = fn_inputs - return all_function_configs, all_input_configs + return all_function_configs, all_input_configs, output_type, output_format def check_environment_variables(): for key in ENV_FLAGS: @@ -1202,10 +1247,24 @@ def add_tflops_columns(df: pd.DataFrame, func_cfg: FunctionConfig) -> pd.DataFra df[tf_col] = flops / df[ms_col] * 1e-9 return df +def generate_output_filename(function_configs, output_type, output_format): + # create a timestamp + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + + # simple filename format + base_filename = f"benchmark_{timestamp}" + + if output_format == "csv": + return base_filename + ".csv" + else: # markdown + return base_filename + ".md" + def main(): """ Main function to run benchmarks. """ + global VERBOSE + # check environment variables check_environment_variables() @@ -1213,19 +1272,37 @@ def main(): total_start_time = time.time() # process args to get function configs and input configs - function_configs, all_input_configs = process_args() + function_configs, all_input_configs, output_type, output_format = process_args() + + # Print summary of what will be benchmarked (always show this) + print(f"\nBenchmarking {len(function_configs)} configuration(s):") + unique_fns = set(fc.fn_name for fc in function_configs) + print(f" Functions: {', '.join(unique_fns)}") + unique_backends = set(fc.backend for fc in function_configs) + print(f" Backends: {', '.join(unique_backends)}") + unique_modes = set(fc.mode for fc in function_configs) + print(f" Modes: {', '.join(unique_modes)}") + print() # run benchmarks for each function configuration combined_ms_df = None combined_tf_df = None input_cols = ["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K", "D_HEAD", "CAUSAL", "DROPOUT"] - for func_config in function_configs: + + for i, func_config in enumerate(function_configs, 1): + # Progress indicator + if not VERBOSE: + print(f"[{i}/{len(function_configs)}] ", end='') + # run benchmark with the input configs for this function config input_configs = all_input_configs[func_config] - df = run_benchmark(func_config, input_configs) - df = add_tflops_columns(df, func_config) + df, elapsed_time = run_benchmark(func_config, input_configs) + + if VERBOSE: + print(f"Total time for benchmarking {func_config.fn_name} in {func_config.mode} mode with {func_config.dtype}: {elapsed_time:.2f} seconds") # add to combined table + df = add_tflops_columns(df, func_config) ms_cols = [c for c in df.columns if c.endswith('_ms')] tf_cols = [c for c in df.columns if c.endswith('_tflops')] @@ -1244,7 +1321,7 @@ def main(): # print total time for all benchmarks total_elapsed_time = time.time() - total_start_time - print(f"Total time for all benchmarks: {total_elapsed_time:.2f} seconds") + print(f"Total benchmark time: {total_elapsed_time:.1f} seconds") # save combined data and make comparisons if we have multiple function configs has_multiple_func_configs = False # len(function_configs) > 1 @@ -1282,19 +1359,33 @@ def main(): print(f"Comparison Results (triton vs ck):") print(f"Ratio values: values > 1 mean triton is faster (by that factor), values < 1 mean ck is faster") - if combined_ms_df is not None: - print("\nCombined wall‑time (ms) table:") - print(combined_ms_df) - combined_ms_df.to_csv("benchmark_ms.csv", index=False) - with open("benchmark_ms.md", 'w') as f: - f.write(combined_ms_df.to_markdown(index=False, floatfmt=".2f")) - - if combined_tf_df is not None: - print("\nCombined throughput (TFLOPs) table:") - print(combined_tf_df) - combined_tf_df.to_csv("benchmark_tflops.csv", index=False) - with open("benchmark_tflops.md", 'w') as f: - f.write(combined_tf_df.to_markdown(index=False, floatfmt=".2f")) + # output based on selected metric + if output_type == "ms": + if combined_ms_df is not None: + filename = generate_output_filename(function_configs, "ms", output_format) + print(f"\nCombined wall-time (ms) table:") + print(combined_ms_df) + + if output_format == "csv": + combined_ms_df.to_csv(filename, index=False) + print(f"Results saved to: {filename}") + else: # markdown + with open(filename, 'w') as f: + f.write(combined_ms_df.to_markdown(index=False, floatfmt=".2f")) + print(f"Results saved to: {filename}") + else: # output_type == "tflops" + if combined_tf_df is not None: + filename = generate_output_filename(function_configs, "tflops", output_format) + print(f"\nCombined throughput (TFLOPs) table:") + print(combined_tf_df) + + if output_format == "csv": + combined_tf_df.to_csv(filename, index=False) + print(f"Results saved to: {filename}") + else: # markdown + with open(filename, 'w') as f: + f.write(combined_tf_df.to_markdown(index=False, floatfmt=".2f")) + print(f"Results saved to: {filename}") if __name__ == "__main__": main() \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill.py b/flash_attn/flash_attn_triton_amd/bwd_prefill.py deleted file mode 100644 index 44e2c294b0d..00000000000 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill.py +++ /dev/null @@ -1,814 +0,0 @@ -from typing import Literal, Optional -import torch -import triton -import triton.language as tl -from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, get_shapes_from_layout, get_strides_from_layout, is_fp8, write_dropout_mask, create_dropout_mask - -# TODO: move this into utils.py so it's shared among kernels -# NOTE: triton fails to import tl.constexprs so create them here for the file -tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) -tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) - -@triton.jit -def _bwd_preprocess( - Out, - DO, - Delta, - stride_oz, stride_oh, stride_om, stride_ok, - stride_doz, stride_doh, stride_dom, stride_dok, - stride_deltaz, stride_deltah, stride_deltam, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - DESCALE_do, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - N_CTX_Q: tl.constexpr, - Z: tl.constexpr, - H: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, -): - pid_bh = tl.program_id(0) - pid_m = tl.program_id(1) - - # Compute batch and head indices - off_z = pid_bh // H - off_h = pid_bh % H - - if IS_VARLEN: - # Compute sequence lengths for the current batch - q_start = tl.load(cu_seqlens_q + off_z) - q_end = tl.load(cu_seqlens_q + off_z + 1) - k_start = tl.load(cu_seqlens_k + off_z) - k_end = tl.load(cu_seqlens_k + off_z + 1) - - # Compute actual sequence lengths - N_CTX_Q = q_end - q_start - N_CTX_K = k_end - k_start - else: - q_start = 0 - k_start = 0 - N_CTX_Q = max_seqlen_q - N_CTX_K = max_seqlen_k - - off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_d = tl.arange(0, BLOCK_DMODEL) - - # create masks - mask_m = off_m < N_CTX_Q - mask_d = off_d < ACTUAL_BLOCK_DMODEL - - # compute offsets - o_offset = Out + off_z * stride_oz + off_h * stride_oh + q_start * stride_om - do_offset = DO + off_z * stride_oz + off_h * stride_oh + q_start * stride_om - - # compute pointers - out_ptrs = o_offset + off_m[:, None] * stride_om + off_d[None, :] * stride_ok - do_ptrs = do_offset + off_m[:, None] * stride_dom + off_d[None, :] * stride_dok - - # load - o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0) - do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0) - - # compute delta - if IS_FP8: - stride_descale_q_z = H - descale_do = tl.load(DESCALE_do + off_z * stride_descale_q_z + off_h) - - # NOTE: do is scaled into the fp8 range and o is in fp8 but should be in the same scale as fp32 - delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) - else: - delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) - - # write-back delta - delta_offset = Delta + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam - delta_ptrs = delta_offset + off_m * stride_deltam - tl.store(delta_ptrs, delta, mask=mask_m) - - -@triton.jit -def _bwd_kernel_one_col_block( - Q, - K, - V, - sm_scale, - Out, - DO, - DQ, - DK, - DV, - L, - D, - q_offset, - k_offset, - v_offset, - do_offset, - dq_offset, - dk_offset, - dv_offset, - l_offset, - delta_offset, - dropout_offset, - stride_dq_all, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vn, - stride_vk, - stride_deltaz, - stride_deltah, - stride_deltam, - stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, - N_CTX_Q, - N_CTX_K, - start_n, - num_block_m, - num_block_n, - dropout_p, - philox_seed, - batch_philox_offset, - descale_q, - descale_k, - descale_v, - descale_do, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - SEQUENCE_PARALLEL: tl.constexpr, - CAUSAL: tl.constexpr, - DROPOUT: tl.constexpr, - USE_EXP2: tl.constexpr, - GROUP_SIZE: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - if CAUSAL: - # TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M - lo = 0 - else: - lo = 0 - - # initialize col and head offsets - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - - # masks - mask_n = offs_n < N_CTX_K - mask_d = offs_d < ACTUAL_BLOCK_DMODEL - kv_mask = mask_n[:, None] & mask_d[None, :] - - - # initialize grad accumulators - dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) - dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) - - # load k and v once per column block - k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk - v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk - k = tl.load(k_ptrs, mask=kv_mask, other=0.0) - kT = tl.trans(k) - vT = tl.trans(tl.load(v_ptrs, mask=kv_mask, other=0.0)) - - # loop over rows - for start_m in range(lo, num_block_m): - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - dq_ptrs = dq_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - do_ptrs = do_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - - # update mask as row block changes - mask_m = offs_m < N_CTX_Q - q_mask = mask_m[:, None] & mask_d[None, :] - - # load q, k, v, do on-chip - q = tl.load(q_ptrs, mask=q_mask, other=0.0) - do = tl.load(do_ptrs, mask=q_mask, other=0.0) - - # recompute p = softmax(qk, dim=-1).T - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - if IS_FP8: - qk += (tl.dot(q, kT) * descale_q * descale_k) - else: - qk += tl.dot(q, kT) - - if CAUSAL: - col_offset = N_CTX_Q - N_CTX_K - causal_mask = offs_m[:, None] >= (col_offset + offs_n[None, :]) - qk = tl.where(causal_mask, qk, float("-inf")) - - l_ptrs = l_offset + offs_m * stride_deltam - l_i = tl.load(l_ptrs, mask=mask_m) - - # compute p - if USE_EXP2: - RCP_LN2: tl.constexpr = 1.4426950408889634 - qk *= sm_scale * RCP_LN2 - l_i *= RCP_LN2 - p = tl.math.exp2(qk - l_i[:, None]) - else: - qk *= sm_scale - p = tl.math.exp(qk - l_i[:, None]) - - # mask block in the cases where the data is smaller the block size - p_mask = mask_m[:, None] & mask_n[None, :] - p = tl.where(p_mask, p, 0.0) - - if DROPOUT: - # NOTE: must create a new var p_drop to prevent p (which is used later to compute ds) from changing - philox_offset = batch_philox_offset + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn - # print("philox_seed:", philox_seed) - # print("philox_offset:", philox_offset) - if tl_DROPOUT_USE_PYTORCH: - dropout_ptrs = dropout_offset + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn - dropout_mask = tl.load(dropout_ptrs, mask=p_mask) - else: - rand_vals = tl.rand(philox_seed, philox_offset) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1/ (1 - dropout_p) - - if tl_DROPOUT_DUMP: - dropout_ptrs = dropout_offset + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn - tl.store(dropout_ptrs, dropout_mask, mask=p_mask) - - # apply dropout mask - p_drop = tl.where(dropout_mask, p, 0.0) - p_drop_scaled = p_drop * dropout_scale - - # compute dv - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(p_drop_scaled, FP8_MAX) - dv += (tl.dot(tl.trans(p_drop_scaled * scale_p_dropout).to(do.type.element_ty), do) * descale_p_dropout * descale_do) - else: - dv += tl.dot(tl.trans(p_drop_scaled).to(do.type.element_ty), do) - - # compute dp - if IS_FP8: - dp_drop_scaled = (tl.dot(do, vT) * descale_do * descale_v) - else: - dp_drop_scaled = tl.dot(do, vT) - dp = tl.where(dropout_mask, dp_drop_scaled, 0.0) * dropout_scale - else: - - # compute dv - if IS_FP8: - scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) - dv += (tl.dot(tl.trans(p * scale_p).to(do.type.element_ty), do) * descale_p * descale_do) - else: - dv += tl.dot(tl.trans(p).to(do.type.element_ty), do) - - # compute dp - if IS_FP8: - dp = (tl.dot(do, vT) * descale_do * descale_v) - else: - dp = tl.dot(do, vT) - - - # load delta - delta_ptrs = delta_offset + offs_m * stride_deltam - delta_i = tl.load(delta_ptrs, mask=mask_m) - - # compute ds - dscores_scaled = (p * (dp - delta_i[:, None])) - ds = dscores_scaled * sm_scale - ds = tl.where(p_mask, ds, 0.0) - - # compute descale_ds - if IS_FP8: - scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) - else: - scale_ds, descale_ds = 1.0, 1.0 - - # compute dk - if IS_FP8: - dk += (tl.dot(tl.trans(ds * scale_ds).to(q.type.element_ty), q) * descale_ds * descale_q) - else: - dk += tl.dot(tl.trans(ds).to(q.type.element_ty), q) - - # compute dq - if SEQUENCE_PARALLEL: - if IS_FP8: - dq = (tl.dot((ds * scale_ds).to(k.type.element_ty), k) * descale_ds * descale_k) - else: - dq = tl.dot(ds.to(k.type.element_ty), k) - else: - dq = tl.load(dq_ptrs, mask=q_mask, other=0.0) - if IS_FP8: - dq += (tl.dot((ds * scale_ds).to(k.type.element_ty), k) * descale_ds * descale_k) - else: - dq += tl.dot(ds.to(k.type.element_ty), k) - tl.store(dq_ptrs, dq.to(Q.dtype.element_ty), mask=q_mask) - - # write-back dv and dk - dk_ptrs = dk_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk - dv_ptrs = dv_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk - - # write-back - if GROUP_SIZE != 1: - # use atomic_add to properly accumulate gradients from multiple query heads - tl.atomic_add(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask) - tl.atomic_add(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask) - else: - tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask) - tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask) - -@triton.jit -def _bwd_kernel( - Q, - K, - V, - sm_scale, - Out, - DO, - DQ, - DK, - DV, - L, - Delta, - Dropout_mask, - DESCALE_q, - DESCALE_k, - DESCALE_v, - DESCALE_do, - stride_dq_all, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vn, - stride_vk, - stride_deltaz, - stride_deltah, - stride_deltam, - stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, - Z, - HQ, - HK, - num_block_m, - num_block_n, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset_base, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - SEQUENCE_PARALLEL: tl.constexpr, - CAUSAL: tl.constexpr, - DROPOUT: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_VARLEN: tl.constexpr, - GROUP_SIZE: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - # program ids - off_zh = tl.program_id(0) - if SEQUENCE_PARALLEL: - start_n = tl.program_id(1) - off_z = off_zh // HQ - off_hq = off_zh % HQ - - # check if GQA/MQA - if GROUP_SIZE != 1: - off_hk = off_hq // GROUP_SIZE - else: - off_hk = off_hq - - if IS_VARLEN: - # Compute sequence lengths for the current batch - q_start = tl.load(cu_seqlens_q + off_z) - q_end = tl.load(cu_seqlens_q + off_z + 1) - k_start = tl.load(cu_seqlens_k + off_z) - k_end = tl.load(cu_seqlens_k + off_z + 1) - - # Compute actual sequence lengths - N_CTX_Q = q_end - q_start - N_CTX_K = k_end - k_start - else: - q_start = 0 - k_start = 0 - N_CTX_Q = max_seqlen_q - N_CTX_K = max_seqlen_k - - # input tensor offsets - q_offset = Q + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm - k_offset = K + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn - v_offset = V + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn - do_offset = DO + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm - l_offset = L + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam - delta_offset = Delta + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam - - if DROPOUT: - batch_philox_offset = philox_offset_base + off_z * stride_dropoutz + off_hq * stride_dropouth #+ q_start * stride_dropoutm - dropout_offset = Dropout_mask + off_z * stride_dropoutz + off_hq * stride_dropouth #+ q_start * stride_dropoutm - else: - batch_philox_offset = 0 - dropout_offset = 0 - - if IS_FP8: - stride_descale_q_z = HQ - stride_descale_kv_z = HK - - descale_q = tl.load(DESCALE_q + off_z * stride_descale_q_z + off_hq) - descale_k = tl.load(DESCALE_k + off_z * stride_descale_kv_z + off_hk) - descale_v = tl.load(DESCALE_v + off_z * stride_descale_kv_z + off_hk) - descale_do = tl.load(DESCALE_do + off_z * stride_descale_q_z + off_hq) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - # output tensor offsets - dk_offset = DK + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn - dv_offset = DV + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn - if SEQUENCE_PARALLEL: - dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm - else: - dq_offset = DQ + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm - - # inner loop - if SEQUENCE_PARALLEL: - _bwd_kernel_one_col_block( - Q, - K, - V, - sm_scale, - Out, - DO, - DQ, - DK, - DV, - L, - Delta, - q_offset, - k_offset, - v_offset, - do_offset, - dq_offset, - dk_offset, - dv_offset, - l_offset, - delta_offset, - dropout_offset, - stride_dq_all, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vn, - stride_vk, - stride_deltaz, - stride_deltah, - stride_deltam, - stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, - N_CTX_Q, - N_CTX_K, - start_n, - num_block_m, - num_block_n, - dropout_p, - philox_seed, - batch_philox_offset, - descale_q, - descale_k, - descale_v, - descale_do, - BLOCK_M=BLOCK_M, - BLOCK_DMODEL=BLOCK_DMODEL, - ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, - BLOCK_N=BLOCK_N, - SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, - CAUSAL=CAUSAL, - DROPOUT=DROPOUT, - USE_EXP2=USE_EXP2, - GROUP_SIZE=GROUP_SIZE, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX - ) - else: - for start_n in range(0, num_block_n): - _bwd_kernel_one_col_block( - Q, - K, - V, - sm_scale, - Out, - DO, - DQ, - DK, - DV, - L, - Delta, - q_offset, - k_offset, - v_offset, - do_offset, - dq_offset, - dk_offset, - dv_offset, - l_offset, - delta_offset, - dropout_offset, - stride_dq_all, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vn, - stride_vk, - stride_deltaz, - stride_deltah, - stride_deltam, - stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, - N_CTX_Q, - N_CTX_K, - start_n, - num_block_m, - num_block_n, - dropout_p, - philox_seed, - batch_philox_offset, - descale_q, - descale_k, - descale_v, - descale_do, - BLOCK_M=BLOCK_M, - BLOCK_DMODEL=BLOCK_DMODEL, - ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, - BLOCK_N=BLOCK_N, - SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, - CAUSAL=CAUSAL, - DROPOUT=DROPOUT, - USE_EXP2=USE_EXP2, - GROUP_SIZE=GROUP_SIZE, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX - ) - - -# NOTE: smaller blocks have lower accuracy. more accumulation error probably 128 * 128 seems good but leads to oom. 64 * 64 has accumulation errors but no oom. -def attention_prefill_backward_triton_impl( - do: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - softmax_lse: torch.Tensor, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - layout: Literal["bshd", "bhsd", "thd"], - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - philox_seed: Optional[int], - philox_offset: Optional[int], - use_exp2: bool, - sequence_parallel: bool = True, - # fp8 - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None, -): - if DEBUG: - print() - print("attention_prefill_backward_triton_impl") - print("do:", do, do.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("o:", o, o.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape if dq is not None else None) - print("dk:", dk, dk.shape if dk is not None else None) - print("dv:", dv, dv.shape if dv is not None else None) - print("sm_scale:", sm_scale) - print("alibi_slopes:", alibi_slopes) - print("causal:", causal) - print("layout:", layout) - print("cu_seqlens_q:", cu_seqlens_q) - print("cu_seqlens_k:", cu_seqlens_k) - print("max_seqlen_q:", max_seqlen_q) - print("max_seqlen_k:", max_seqlen_k) - print("dropout_p:", dropout_p) - print("philox_seed:", philox_seed) - print("philox_offset:", philox_offset) - print("use_exp2:", use_exp2) - print("sequence_parallel:", sequence_parallel) - print("descale_q:", descale_q) - print("descale_k:", descale_k) - print("descale_v:", descale_v) - print("descale_do:", descale_do) - - IS_FP8 = is_fp8(q) - if IS_FP8: - FP8_MAX=torch.finfo(q.dtype).max - else: - FP8_MAX=None - - # make contiguous - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - softmax_lse = softmax_lse.contiguous() - - # get strides and shape - batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shapes_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) - q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout) - stride_qz, stride_qh, stride_qm, stride_qk = q_strides - stride_kz, stride_kh, stride_kn, stride_kk = k_strides - stride_vz, stride_vh, stride_vn, stride_vk = v_strides - stride_oz, stride_oh, stride_om, stride_ok = o_strides - is_varlen = layout == "thd" - group_size = nheads_q // nheads_k - use_dropout = (dropout_p > 0.0) - - # FIXME: some configs lead to oom for some reason when using 64 x 64 blocks - if max_seqlen_q <= 32 or max_seqlen_k <= 32: - BLOCK_M = 32 - BLOCK_N = 32 - else: - BLOCK_M = 64 - BLOCK_N = 64 - - if DEBUG: - print("BLOCK_M:", BLOCK_M) - print("BLOCK_N:", BLOCK_N) - - num_warps = 4 # NOTE: original is 8. changing it to 1 caused issues be careful - num_stages = 1 - waves_per_eu = 1 - - # divide up the problem - num_blocks_m = triton.cdiv(max_seqlen_q, BLOCK_M) - num_blocks_n = triton.cdiv(max_seqlen_k, BLOCK_N) - - # get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() - padded_d_model = max(padded_d_model, 16) - BLOCK_DMODEL = padded_d_model - ACTUAL_BLOCK_DMODEL = head_size - - do = do.contiguous() - - # deal with dq - if sequence_parallel: - dq = dq.unsqueeze(0).repeat(num_blocks_n, *([1] * len(q.shape))) # we do repeat instead of expand because we need to write data so views are not enough - stride_dq_all = dq.stride()[0] - - # assert contiguous - assert do.is_contiguous() - assert q.is_contiguous() - assert k.is_contiguous() - assert v.is_contiguous() - assert o.is_contiguous() - assert softmax_lse.is_contiguous() - - # init delta - delta = torch.zeros_like(softmax_lse) - if is_varlen: - stride_deltam, stride_deltah = delta.stride() - stride_deltaz = 0 - else: - stride_deltaz, stride_deltah, stride_deltam = delta.stride() - - # dropout mask tensor for debugging. We dump the dropout mask created in the kernel for testing - if use_dropout: - if DROPOUT_USE_PYTORCH: - dropout_mask = create_dropout_mask(dropout_p, (batch, nheads_q, max_seqlen_q, max_seqlen_k), seed = philox_seed) - else: - dropout_mask = torch.zeros((batch, nheads_q, max_seqlen_q, max_seqlen_k), device=q.device, - dtype=torch.float32) - stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn = (dropout_mask.stride(0), dropout_mask.stride(1), dropout_mask.stride(2), dropout_mask.stride(3)) - else: - dropout_mask = None - stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn = (0, 0 , 0 , 0) - - - _bwd_preprocess[(batch * nheads_q, num_blocks_m)]( - o, - do, - delta, - stride_oz, stride_oh, stride_om, stride_ok, - stride_oz, stride_oh, stride_om, stride_ok, # FIXME: don't share strides with derivatives this was causing a lot of issues - stride_deltaz, stride_deltah, stride_deltam, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - descale_do, - BLOCK_M=BLOCK_M, - BLOCK_DMODEL=BLOCK_DMODEL, - ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, - N_CTX_Q=max_seqlen_q, - Z=batch, - H=nheads_q, - IS_VARLEN=is_varlen, - IS_FP8=IS_FP8 - ) - - if DEBUG: - print("delta:", delta, delta.shape) - print("group_size:", group_size) - - _bwd_kernel[(batch * nheads_q, num_blocks_n if sequence_parallel else 1)]( - q, - k, - v, - sm_scale, - o, - do, - dq, - dk, - dv, - softmax_lse, - delta, - dropout_mask, - descale_q, - descale_k, - descale_v, - descale_do, - stride_dq_all, - stride_qz, stride_qh, stride_qm, stride_qk, # FIXME: don't share strides with derivatives this was causing a lot of issues - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_deltaz, stride_deltah, stride_deltam, - stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, - batch, - nheads_q, - nheads_k, - num_blocks_m, - num_blocks_n, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, philox_seed, philox_offset, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_DMODEL=BLOCK_DMODEL, - ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, - SEQUENCE_PARALLEL=sequence_parallel, - CAUSAL=causal, - DROPOUT=use_dropout, - USE_EXP2=use_exp2, - num_warps=num_warps, - num_stages=num_stages, - waves_per_eu = waves_per_eu, - IS_VARLEN=is_varlen, - GROUP_SIZE=group_size, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX - ) - - if sequence_parallel: - dq = dq.sum(dim=0) - - if DEBUG: - print("attention_prefill_backward_triton_impl outputs") - print("dv:", dv, dv.shape) - print("dk:", dk, dk.shape) - print("dq:", dq, dq.shape) - if use_dropout: - print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None) - print("dropout_fraction bwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item()) - write_dropout_mask(dropout_mask, "dropout_mask_bwd") - - return delta diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py similarity index 55% rename from flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py rename to flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py index af3f8790026..e969a3770b8 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py @@ -1,793 +1,10 @@ import torch import triton import triton.language as tl +from flash_attn.flash_attn_triton_amd.utils import compute_fp8_scaling_factors from typing import Optional, Tuple -@triton.jit -def cdiv_fn(x, y): - return (x + y - 1) // y - -@triton.jit -def compute_fp8_scaling_factors(x, fp8_max: tl.constexpr): - # compute fp8 scaling and descaling factor for a block - x_amax = tl.max(tl.abs(x)) # NOTE: abs deals with negative values - x_amax = tl.where(x_amax <= 1e-9, 1e-9, x_amax) - scale_x = fp8_max / x_amax - descale_x = x_amax / fp8_max - return scale_x, descale_x - -def is_fp8(x): - if x.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}: - if arch_supports_fp8(): - return True - else: - raise RuntimeError("This device does not support fp8") - else: - return False - - -def cast_to_fp8( - x: torch.Tensor, - fp8_dtype, - layout, - clamp_val=1e-9, -): - if len(x.shape) != 4: - raise ValueError(f"'bshd' tensor should have shape [batch, seqlen, heads, dim], got {x.shape}") - reduce_dims = (1, 3) # seq_len and dim dimensions - - # Compute the absolute max along reduce_dims, clamped to avoid 0-scale - x_abs_max = x.abs().amax(dim=reduce_dims) - x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) - - # Unsqueeze back to a shape suitable for broadcast - unsqueeze_dims = sorted(reduce_dims) - for d in unsqueeze_dims: - x_abs_max = x_abs_max.unsqueeze(d) - - # compute scale and descale - fp8_max = torch.finfo(fp8_dtype).max - scale = fp8_max / x_abs_max - descale_factor = x_abs_max / fp8_max - - # cast to FP8, optionally setting requires_grad - x_fp8 = (x * scale).to(fp8_dtype) - - return x_fp8, descale_factor - - -def cast_varlen_to_fp8( - x: torch.Tensor, - fp8_dtype: torch.dtype, - cu_seqlens, - clamp_val: float = 1e-9, -) -> tuple[torch.Tensor, torch.Tensor]: - # validate tensor shape - if len(x.shape) != 3: - raise ValueError(f"tensor should have shape [total_seqlen, heads, dim], got {x.shape}") - num_heads = x.shape[1] - - # Get batch size from cu_seqlens - batch = cu_seqlens.shape[0] - 1 - fp8_max = torch.finfo(fp8_dtype).max - - # Compute scale and descale factors per sequence - x_fp8 = torch.zeros_like(x, dtype=fp8_dtype) - descale_factors = torch.zeros((batch, num_heads), device=x.device, dtype=torch.float32) - - for i in range(batch): - start = cu_seqlens[i] - end = cu_seqlens[i + 1] - x_slice = x[start:end] # Slice for current sequence - - # Standard tensor (0: seq_len, 2: head_dim) - x_abs_max = x_slice.abs().amax(dim=(0, 2)) # [heads] - - # apply minimum clamping - x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) - - # compute scale and descale factors - scale_i = fp8_max / x_abs_max - descale_i = x_abs_max / fp8_max - - # store descale factors - descale_factors[i, :] = descale_i - - scale_reshape = scale_i.reshape(1, num_heads, 1) - - # scale and cast to FP8 - x_fp8[start:end] = (x_slice * scale_reshape).to(fp8_dtype) - - return x_fp8, descale_factors - - -#TODO Move this to a common folder. Will need to add future arch list -def get_arch(): - return triton.runtime.driver.active.get_current_target().arch - -def is_hip(): - return triton.runtime.driver.active.get_current_target().backend == "hip" - -def arch_supports_fp8(): - return is_hip() and get_arch() in ('gfx942') - -@triton.jit -def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): - if offset_first is not None and offset_second is not None: - mask = (offset_first[:, None] < boundary_first) & \ - (offset_second[None, :] < boundary_second) - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_first is not None: - mask = offset_first[:, None] < boundary_first - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_second is not None: - mask = offset_second[None, :] < boundary_second - tensor = tl.load(ptrs, mask=mask, other=0.0) - else: - tensor = tl.load(ptrs) - return tensor - -@triton.jit -def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): - # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix - # for casual mask we want something like this where (1 is kept and 0 is masked) - # seqlen_q = 2 and seqlen_k = 5 - # 1 1 1 1 0 - # 1 1 1 1 1 - # seqlen_q = 5 and seqlen_k = 2 - # 0 0 - # 0 0 - # 0 0 - # 1 0 - # 1 1 - # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal - # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False - # 1. offs_m[:,None] = [[0], - # [1], - # 2. offs_m[:,None] + seqlen_k = [[5], - # [6], - # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], - # [4], - # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], - # [4], [ 4, 3, 2, 1, 0]] - # 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], - # [ -4, -3, -2, -1, 0]], - relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - if transpose: - return alibi_block.T - else: - return alibi_block - -@triton.jit -def _attn_fwd_inner( - acc, - l_i, - m_i, - q, - k_ptrs, - v_ptrs, - stride_kn, - stride_vk, - stride_sn, - start_m, - seqlen_k, - seqlen_q, - dropout_p, - sd_mask_ptrs, - dropout_mask_ptrs, - philox_seed, - philox_ptrs, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - alibi_slope, - descale_q, - descale_k, - descale_v, - OFFS_M: tl.constexpr, - OFFS_N: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_DMODEL_POW2: tl.constexpr, - SM_SCALE: tl.constexpr, - IS_CAUSAL: tl.constexpr, - MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_SCORES: tl.constexpr, - PADDED_HEAD: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - RCP_LN2: tl.constexpr = 1.4426950408889634 - - # loop over k, v, and update accumulator - - for start_n in range(block_min, block_max, BLOCK_N): - # For padded blocks, we will overrun the tensor size if - # we load all BLOCK_N. For others, the blocks are all within range. - if MASK_STEPS: - k_offs_n = start_n + tl.arange(0, BLOCK_N) - else: - k_offs_n = None - k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL_POW2) - k = load_fn(k_ptrs, k_offs_k, k_offs_n, BLOCK_DMODEL, seqlen_k) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - # We start from end of seqlen_k so only the first iteration would need - # to be checked for padding if it is not a multiple of block_n - # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: - # If this is the last block / iteration, we want to - # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. - # last step might get wasted but that is okay. check if this masking works For - # that case. - if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], seqlen_k, dtype=tl.int32) - size_n = start_n + OFFS_N[None, :] - mask = size_n < boundary_m[:, None] - qk = tl.where(mask, qk, float("-inf")) - - # compute masks - q_mask = (OFFS_M[:, None] < seqlen_q) - k_mask = ((start_n + tl.arange(0, BLOCK_N))[None, :] < seqlen_k) - p_mask = q_mask & k_mask - - # -- compute qk ---- - if IS_FP8: - qk += (tl.dot(q, k) * descale_q * descale_k) - else: - qk += tl.dot(q, k) - qk_scaled = qk * SM_SCALE - if IS_CAUSAL: - causal_boundary = start_n + offs_n_causal - causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] - qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) - - if alibi_slope is not None: - # Compute the global position of each token within the sequence - global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - global_n_positions = start_n + tl.arange(0, BLOCK_N) - alibi_block = compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, global_m_positions, - global_n_positions) - qk_scaled += alibi_block - # get max scores so far - m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) - - # scale and subtract max - q_shifted = qk_scaled - m_ij[:, None] - - # Compute scaled QK and softmax probabilities - p = tl.math.exp2(q_shifted * RCP_LN2) - - # CAVEAT: Must update l_ij before applying dropout - l_ij = tl.sum(p, 1) - if ENABLE_DROPOUT: - rng_output = tl.rand(philox_seed, philox_ptrs) # TODO: use tl.randint for better performance - dropout_mask = rng_output > dropout_p - tl.store(dropout_mask_ptrs, dropout_mask, mask=p_mask) - - # return scores with negative values for dropped vals - sd_mask = tl.where(dropout_mask, p, -p) - tl.store(sd_mask_ptrs, sd_mask, mask=p_mask) - - # apply dropout mask in place - p = tl.where(dropout_mask, p, 0.0) - elif RETURN_SCORES: - # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - tl.store(sd_mask_ptrs, p, mask=p_mask) - - # -- update output accumulator -- - # alpha is an adjustment factor for acc and li as we loop and find new maxes - # store the diff in maxes to adjust acc and li as we discover new maxes - m_diff = m_i - m_ij - alpha = tl.math.exp2(m_diff * RCP_LN2) - acc = acc * alpha[:, None] - v = load_fn(v_ptrs, k_offs_n, k_offs_k, seqlen_k, BLOCK_DMODEL) - # -- update m_i and l_i - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - - if IS_FP8: - scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) - acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) - else: - acc += tl.dot(p.to(v.type.element_ty), v) - - k_ptrs += BLOCK_N * stride_kn - v_ptrs += BLOCK_N * stride_vk - if RETURN_SCORES: - sd_mask_ptrs += BLOCK_N * stride_sn - - if ENABLE_DROPOUT: - dropout_mask_ptrs += BLOCK_N * stride_sn - philox_ptrs += BLOCK_N * stride_sn - - return acc, l_i, m_i - - -@triton.jit -def _attn_fwd(q_ptr: torch.Tensor, - k_ptr: torch.Tensor, - v_ptr: torch.Tensor, - descale_q_ptr: torch.Tensor, - descale_k_ptr: torch.Tensor, - descale_v_ptr: torch.Tensor, - out_ptr: torch.Tensor, - alibi_slopes_ptr: torch.Tensor, - s_dmask_ptr: torch.Tensor, - dropout_mask_ptr: torch.Tensor, - softmax_lse_ptr: torch.Tensor, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_oz, stride_oh, stride_om, stride_on, - stride_alibi_z, stride_alibi_h, - stride_sd_z, stride_sd_h, stride_sd_m, stride_sd_n, - stride_lse_z, stride_lse_h, stride_lse_m, - sm_scale, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset, - SEQLEN_Q: tl.constexpr, - SEQLEN_K: tl.constexpr, - IS_CAUSAL: tl.constexpr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_DMODEL_POW2: tl.constexpr, - RETURN_SCORES: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - VARLEN: tl.constexpr, -): - #calculate offsets - off_z = tl.program_id(0) #batch - off_q_head = tl.program_id(1) #num_q_heads - start_m = tl.program_id(2) #seqlen_q - - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL_POW2) - - if VARLEN: - cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) - cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) - - seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - # We have a one-size-fits-all grid in id(0). Some seqlens might be too - # small for all start_m so for those we return early. - if start_m * BLOCK_M > seqlen_q: - return - cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) - cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) - seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start - else: - cu_seqlens_q_start = 0 - cu_seqlens_k_start = 0 - seqlen_q = SEQLEN_Q - seqlen_k = SEQLEN_K - - n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - - # Now we compute whether we need to exit early due to causal masking. - # This is because for seqlen_q > seqlen_k, M rows of the attn scores - # are completely masked, resulting in 0s written to the output, and - # inf written to LSE. We don't need to do any GEMMs in this case. - # This block of code determines what N is, and if this WG is operating - # on those M rows. - if (IS_CAUSAL): - # If seqlen_q == seqlen_k, the attn scores are a square matrix. - # If seqlen_q != seqlen_k, attn scores are rectangular which means - # the causal mask boundary is bottom right aligned, and ends at either - # the top edge (seqlen_q < seqlen_k) or left edge. - - # This captures the decrease in n_blocks if we have a rectangular attn matrix - n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) - - # This is what adjusts the block_max for the current WG, only - # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks - n_blocks = min(n_blocks, n_blocks_seqlen) - - # If we have no blocks after adjusting for seqlen deltas, this WG is part of - # the blocks that are all 0. We exit early. - if n_blocks <= 0: - offs_out = (off_z * stride_oz + - off_q_head * stride_oh + - cu_seqlens_q_start * stride_om + - offs_m[:, None] * stride_om + - offs_d[None, :] * stride_on) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=out_ptr.type.element_ty) - out_mask = (offs_m[:, None] < seqlen_q) & (offs_d < BLOCK_DMODEL) - tl.store(out_ptr + offs_out, acc, mask=out_mask) - - if softmax_lse_ptr is not None: - offs_lse = (off_z * stride_lse_z + - off_q_head * stride_lse_h + - cu_seqlens_q_start * stride_lse_m + - offs_m*stride_lse_m - ) - lse_mask = offs_m < SEQLEN_Q - lse = tl.full([BLOCK_M], value=0.0, dtype=tl.float32) - tl.store(softmax_lse_ptr + offs_lse, lse, mask=lse_mask) - # TODO: Should dropout and return encoded softmax be handled here too? - - return - - grp_sz:tl.constexpr = NUM_Q_HEADS // NUM_K_HEADS - if grp_sz != 1: #Grouped Query Attention - off_k_head = off_q_head // grp_sz - else: - off_k_head = off_q_head - - #q,k,v offsets - q_offs = (off_z * stride_qz + - off_q_head * stride_qh + - cu_seqlens_q_start * stride_qm + - offs_m[:, None] * stride_qm + offs_d[None, :]*stride_qk - ) - q_ptrs = q_ptr + q_offs - - k_offs = (off_z * stride_kz + - off_k_head * stride_kh + - cu_seqlens_k_start * stride_kn + - offs_d[:, None] * stride_kk + offs_n[None, :]*stride_kn - ) - k_ptrs = k_ptr + k_offs - - v_offs = (off_z * stride_vz + - off_k_head * stride_vh + - cu_seqlens_k_start * stride_vn + - offs_n[:, None] * stride_vn + offs_d[None, :]*stride_vk - ) - v_ptrs = v_ptr + v_offs - - #alibi slopes - if alibi_slopes_ptr is not None: - alibi_offs = off_z * stride_alibi_z + off_q_head * stride_alibi_h - alibi_slope = tl.load(alibi_slopes + alibi_offs) - else: - alibi_slope = None - - #s_dmask (return_scores) - if s_dmask_ptr is not None: - s_dmask_offs = (off_z * stride_sd_z + - off_q_head * stride_sd_h + - offs_m[:, None] * stride_sd_m + - offs_n[None, :] * stride_sd_n - ) - s_dmask_ptrs = s_dmask_ptr + s_dmask_offs - else: - s_dmask_ptrs = None - - #dropout - if dropout_mask_ptr is not None: - dropout_mask_offs = (off_z * stride_sd_z + - off_q_head * stride_sd_h + - offs_m[:, None] * stride_sd_m + - offs_n[None, :] * stride_sd_n - ) - dropout_mask_ptrs = dropout_mask_ptr + dropout_mask_offs - philox_ptrs = (philox_offset + - off_z * stride_sd_z + - off_q_head * stride_sd_h + - offs_m[:, None] * stride_sd_m + - offs_n[None, :] * stride_sd_n - ) - else: - dropout_mask_ptrs = None - philox_ptrs = None - - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=tl.float32) - if (BLOCK_DMODEL == BLOCK_DMODEL_POW2): - q_mask = (offs_m[:, None] < seqlen_q) - else: - q_mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < BLOCK_DMODEL) - q = tl.load(q_ptrs, mask=q_mask, other=0.0) - if IS_FP8: - descale_q = tl.load(descale_q_ptr + off_z * stride_descale_q_z + off_q_head) - descale_k = tl.load(descale_k_ptr + off_z * stride_descale_k_z + off_k_head) - descale_v = tl.load(descale_v_ptr + off_z * stride_descale_v_z + off_k_head) - else: - descale_q, descale_k ,descale_v = 1.0, 1.0, 1.0 - - n_extra_tokens = 0 - if seqlen_k < BLOCK_N: - n_extra_tokens = BLOCK_N -seqlen_k - elif seqlen_k % BLOCK_N: - n_extra_tokens = seqlen_k % BLOCK_N - - #if CAUSAL, then determine masked_blocks and full blocks - # Here we compute how many full and masked blocks we have. - padded_block_k = n_extra_tokens != 0 - is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) - if IS_CAUSAL: - # There are always at least BLOCK_M // BLOCK_N masked blocks. - # Additionally there might be one more due to dissimilar seqlens. - masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) - else: - # Padding on Q does not need to be masked in the FA loop. - masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. - # In this case we might exceed n_blocks so pick the min. - masked_blocks = min(masked_blocks, n_blocks) - n_full_blocks = n_blocks - masked_blocks - block_min = 0 - block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its actual - # value because there is no masking. Similarly we do not need padding. - if n_full_blocks > 0: - block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, - l_i, - m_i, - q, - k_ptrs, - v_ptrs, - stride_kn, - stride_vn, - stride_sd_n, - start_m, - seqlen_k, - seqlen_q, - dropout_p, - s_dmask_ptrs, dropout_mask_ptrs, philox_seed, philox_ptrs, - block_min, block_max, 0, 0, 0, alibi_slope, - descale_q, descale_k, descale_v, - offs_m, offs_n, BLOCK_M, BLOCK_N, BLOCK_DMODEL,BLOCK_DMODEL_POW2, - sm_scale, False, MASK_STEPS=False, ENABLE_DROPOUT=ENABLE_DROPOUT, - RETURN_SCORES=RETURN_SCORES, PADDED_HEAD=BLOCK_DMODEL!=BLOCK_DMODEL_POW2, - IS_FP8=IS_FP8, FP8_MAX=FP8_MAX - ) - block_min = block_max - block_max = n_blocks * BLOCK_N - - # Remaining blocks, if any, are full / not masked. - if (masked_blocks > 0): - if IS_CAUSAL: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) - else: - offs_n_causal = 0 - k_ptrs += n_full_blocks * BLOCK_N * stride_kn - v_ptrs += n_full_blocks * BLOCK_N * stride_vn - if RETURN_SCORES: - s_dmask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n - if ENABLE_DROPOUT: - dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n - acc, l_i, m_i = _attn_fwd_inner(acc, - l_i, - m_i, - q, - k_ptrs, - v_ptrs, - stride_kn, stride_vn, stride_sd_n, - start_m, seqlen_k, seqlen_q, - dropout_p, - s_dmask_ptrs, dropout_mask_ptrs, philox_seed, philox_ptrs, - block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, - descale_q, descale_k, descale_v, - offs_m, offs_n, BLOCK_M, BLOCK_N, BLOCK_DMODEL,BLOCK_DMODEL_POW2, - sm_scale, IS_CAUSAL, MASK_STEPS=True, ENABLE_DROPOUT=ENABLE_DROPOUT, - RETURN_SCORES=RETURN_SCORES, PADDED_HEAD=BLOCK_DMODEL!=BLOCK_DMODEL_POW2, - IS_FP8=IS_FP8, FP8_MAX=FP8_MAX - ) - # epilogue - # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. - l_recip = 1 / l_i[:, None] - acc = acc * l_recip - if ENABLE_DROPOUT: - dropout_scale = 1 / (1 - dropout_p) - acc = acc * dropout_scale - # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, - # then we have one block with a row of all NaNs which come from computing - # softmax over a row of all -infs (-inf - inf = NaN). We check for that here - # and store 0s where there are NaNs as these rows should've been zeroed out. - end_m_idx = (start_m + 1) * BLOCK_M - start_m_idx = start_m * BLOCK_M - causal_start_idx = seqlen_q - seqlen_k - if IS_CAUSAL: - if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL_POW2, ), causal_start_idx, dtype=tl.int32) - mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] - z = 0.0 - acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - - # write back LSE(Log Sum Exponents), the log of the normalization constant - overflow_size = end_m_idx - seqlen_q - if softmax_lse_ptr is not None: - RCP_LN2: tl.constexpr = 1.4426950408889634 - LN2: tl.constexpr = 0.6931471824645996 - # compute log-sum-exp in base 2 units - mi_base2 = m_i * RCP_LN2 - softmax_lse = mi_base2 + tl.math.log2(l_i) - # convert back to natural units - softmax_lse *= LN2 - - if IS_CAUSAL: - # zero out nans caused by -infs when doing causal - lse_causal_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx - softmax_lse = tl.where(lse_causal_mask, 0.0, softmax_lse) - - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. - # This is only true for the last M block. For others, overflow_size will be -ve - offs_lse = off_z * stride_lse_z + off_q_head * stride_lse_h + cu_seqlens_q_start * stride_lse_m + offs_m*stride_lse_m - if overflow_size > 0: - boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) - lse_mask = tl.arange(0, BLOCK_M) < boundary - tl.store(softmax_lse_ptr + offs_lse, softmax_lse, mask=lse_mask) # the log of the normalization constant - else: - tl.store(softmax_lse_ptr + offs_lse, softmax_lse) # the log of the normalization constant - - # write back O - offs_out = (off_z * stride_oz + - off_q_head * stride_oh + - cu_seqlens_q_start * stride_om + - offs_m[:, None] * stride_om + - offs_d[None, :] * stride_on) - out_mask = tl.full([BLOCK_M, BLOCK_DMODEL_POW2], 1, dtype=tl.int1) - if overflow_size > 0: - out_mask = out_mask & (offs_m[:, None] < seqlen_q) - if BLOCK_DMODEL != BLOCK_DMODEL_POW2: - out_mask = out_mask & (offs_d[None, :] < BLOCK_DMODEL) - op = acc.to(out_ptr.dtype.element_ty) - tl.store(out_ptr + offs_out, op, mask=out_mask) - -def _flash_attn_forward( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - alibi_slopes: Optional[torch.Tensor], - return_lse: bool, - return_softmax: bool, - max_seqlen_q: int, - max_seqlen_k: int, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k: Optional[torch.Tensor] = None, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - #FP8 - IS_FP8 = is_fp8(q) - FP8_MAX: tl.constexpr=torch.finfo(q.dtype).max - is_varlen = True if cu_seqlens_q is not None else False - - if IS_FP8: - o = torch.zeros_like(q, dtype=torch.float32) - else: - o = torch.zeros_like(q) - if is_varlen: - #Layout for q,k,v is thd ie [total_tokens, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2] - seqlen_k, num_k_heads = max_seqlen_k, k.shape[1] - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - else: - #Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = q.shape - seqlen_k = k.shape[1] - num_k_heads = k.shape[2] - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) - - #padding for head_dim. Power of 2 or 16 - BLOCK_DMODEL_POW2 = triton.next_power_of_2(head_sz) - BLOCK_DMODEL_POW2 = max(BLOCK_DMODEL_POW2, 16) - - #softmax_lse [batch, num_q_heads, seqlen_q] - if return_lse: - if is_varlen: - softmax_lse = torch.zeros((q.shape[0], num_q_heads), device=q.device, dtype=torch.float32) - stride_lse_z, stride_lse_h, stride_lse_m = 0, softmax_lse.stride(1), softmax_lse.stride(0) - else: - softmax_lse = torch.zeros((batch, num_q_heads, max_seqlen_q), device=q.device, dtype=torch.float32) - stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() - else: - softmax_lse = None - - #exp_scores [batch, num_q_heads, seqlen_q, seqlen_k] - enable_dropout = dropout_p > 0.0 - if enable_dropout: - philox_seed = torch.randint(0, 0xffffff, (1,))[0].item() #No specific reason to restrict range to 0xffffff - philox_offset = torch.randint(0, 0xffffff, (1,))[0].item() #Pass in an int, not Tensor - else: - philox_seed = 0 - philox_offset = 0 - if return_softmax or enable_dropout: - s_dmask = torch.zeros((batch, num_q_heads, max_seqlen_q, max_seqlen_k), device=q.device, dtype=torch.float32) - dropout_mask = torch.zeros((batch, num_q_heads, max_seqlen_q, max_seqlen_k), device=q.device, dtype=torch.float32) - else: - s_dmask = None - dropout_mask = None - - - # Best config from ROCm/triton/python/perf-kernels/flash_attention.py::attn_fwd autotuning is BLOCK_M: 128, BLOCK_N: 64, waves_per_eu: 2, num_warps: 4, num_ctas: 1, num_stages: 1 - # Tuned for MI300x - config = { - 'BLOCK_M': 128, - 'BLOCK_N': 64, - 'waves_per_eu': 2, - 'num_warps': 4, - 'num_ctas': 1, - 'num_stages': 1, - } - - grid = lambda META:(batch, num_q_heads, triton.cdiv(seqlen_q, META['BLOCK_M'])) - _attn_fwd[grid](q, - k, - v, - descale_q, - descale_k, - descale_v, - o, - alibi_slopes, - s_dmask, - dropout_mask, - softmax_lse, - *q_strides, - *k_strides, - *v_strides, - descale_q.stride(0) if descale_q is not None else 0, - descale_k.stride(0) if descale_k is not None else 0, - descale_v.stride(0) if descale_v is not None else 0, - *o_strides, - alibi_slopes.stride(0) if alibi_slopes is not None else 0, - alibi_slopes.stride(1) if alibi_slopes is not None else 0, - s_dmask.stride(0) if s_dmask is not None else 0, - s_dmask.stride(1) if s_dmask is not None else 0, - s_dmask.stride(2) if s_dmask is not None else 0, - s_dmask.stride(3) if s_dmask is not None else 0, - stride_lse_z if softmax_lse is not None else 0, - stride_lse_h if softmax_lse is not None else 0, - stride_lse_m if softmax_lse is not None else 0, - softmax_scale, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset, - SEQLEN_Q=max_seqlen_q, - SEQLEN_K=max_seqlen_k, - IS_CAUSAL=causal, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_DMODEL=head_sz, - BLOCK_DMODEL_POW2=BLOCK_DMODEL_POW2, - RETURN_SCORES=return_softmax, - ENABLE_DROPOUT=enable_dropout, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - VARLEN=is_varlen, - **config - ) - - return o, softmax_lse, s_dmask, philox_seed, philox_offset - # This function computes delta given output Out and gradient DO # Here is the I/O shape: # Out: (batch, nhead_q, max_seqlens_q, headDim) @@ -2261,7 +1478,7 @@ def _bwd_kernel_dq_noncausal( dq *= sm_scale tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) -def _flash_attn_backward( +def attention_prefill_backward_triton_fused_atomics_impl( do: torch.Tensor, q: torch.Tensor, k: torch.Tensor, @@ -2589,702 +1806,4 @@ def _flash_attn_backward( waves_per_eu=WAVES_PER_EU, ) - return delta - - -class FlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_softmax, - is_grad_enabled, - fused_backward, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q,k,v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - - - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - alibi_slopes=alibi_slopes, - return_lse=return_lse, - return_softmax=return_softmax and dropout_p > 0, - max_seqlen_q=q.shape[1], - max_seqlen_k=k.shape[1], - ) - - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse) - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - ctx.fused_backward = fused_backward - - - out = out_padded[..., :head_size_og] - result = [out] - if return_lse: - result.append(softmax_lse) - if return_softmax: - result.append(S_dmask) - - return tuple(result) - - @staticmethod - def backward(ctx, do, *args): - q, k, v, out, softmax_lse = ctx.saved_tensors - dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v) - head_size_v_og = do.size(3) - do_padded = do - if head_size_v_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) - _flash_attn_backward( - do_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - None, - None, - max_seqlen_q=q.shape[1], - max_seqlen_k=k.shape[1], - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - fused=ctx.fused_backward, - ) - dq = dq[..., : q.shape[-1]] # We could have padded the head dimension - dk = dk[..., : k.shape[-1]] - dv = dv[..., : v.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None - -def flash_attn_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1,-1), - alibi_slopes=None, - deterministic=True, - return_lse=False, - return_attn_probs=False, - fused_backward=False, -): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k: (batch_size, seqlen, nheads_k, headdim) - v: (batch_size, seqlen, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_lse=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnFunc.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_attn_probs, - torch.is_grad_enabled(), - fused_backward, - ) - - -class FlashAttnFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_softmax, - is_grad_enabled, - fused_backward, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q,k,v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # cast input to fp8 - fp8_dtype = torch.float8_e4m3fnuz - q_fp8, descale_q = cast_to_fp8(q, fp8_dtype, "bshd") - k_fp8, descale_k = cast_to_fp8(k, fp8_dtype, "bshd") - v_fp8, descale_v = cast_to_fp8(v, fp8_dtype, "bshd") - - out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - alibi_slopes=alibi_slopes, - return_lse=return_lse, - return_softmax=return_softmax and dropout_p > 0, - max_seqlen_q=q.shape[1], - max_seqlen_k=k.shape[1], - cu_seqlens_q=None, - cu_seqlens_k=None, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - ) - - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_padded, softmax_lse, descale_q, descale_k, descale_v) - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.fused_backward = fused_backward - - out = out_padded[..., :head_size_og] - result = [out] - if return_lse: - result.append(softmax_lse) - if return_softmax: - result.append(S_dmask) - - return tuple(result) - - @staticmethod - def backward(ctx, do, *args): - q_fp8, k_fp8, v_fp8, out, softmax_lse, descale_q, descale_k, descale_v = ctx.saved_tensors - dq, dk, dv = torch.zeros_like(q_fp8, dtype=torch.float32), torch.zeros_like(k_fp8, dtype=torch.float32), torch.zeros_like(v_fp8, dtype=torch.float32) - head_size_v_og = do.size(3) - do_padded = do - if head_size_v_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) - - fp8_dtype = torch.float8_e4m3fnuz - do_padded_fp8, descale_do = cast_to_fp8(do_padded, fp8_dtype, "bshd") - _flash_attn_backward( - do_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out, - softmax_lse, - dq, - dk, - dv, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - None, - None, - max_seqlen_q=q_fp8.shape[1], - max_seqlen_k=k_fp8.shape[1], - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_do=descale_do, - fused=ctx.fused_backward, - ) - #dq = dq[..., : q_fp8.shape[-1]] # We could have padded the head dimension - #dk = dk[..., : k_fp8.shape[-1]] - #dv = dv[..., : v_fp8.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None - -def flash_attn_fp8_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_lse=False, - return_attn_probs=False, - fused_backward=False, -): - return FlashAttnFP8Func.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_attn_probs, - torch.is_grad_enabled(), - fused_backward, - ) - -class FlashAttnVarlenFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_softmax, - block_table, - is_grad_enabled, - fused_backward, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( - q, - k, - v, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - alibi_slopes=alibi_slopes, - return_lse=return_lse, - return_softmax=return_softmax and dropout_p > 0.0, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - ) - if is_grad: - ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k) - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.fused_backward = fused_backward - out = out_padded[..., :head_size_og] - - result = [out] - if return_lse: - result.append(softmax_lse) - if return_softmax: - result.append(S_dmask) - - return tuple(result) - - @staticmethod - def backward(ctx, do, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors - dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v) - head_size_og = do.size(2) - do_padded = do - if head_size_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8]) - _flash_attn_backward( - do_padded, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_k=ctx.max_seqlen_k, - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - fused=ctx.fused_backward, - ) - dq = dq[..., : q.shape[-1]] # We could have padded the head dimension - dk = dk[..., : k.shape[-1]] - dv = dv[..., : v.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1,-1), - alibi_slopes=None, - deterministic=False, - return_lse=False, - return_attn_probs=False, - block_table=None, - fused_backward=False, -): - """dropout_p should be set to 0.0 during evaluation - Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. - For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. - - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Arguments: - q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. - k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. - cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into q. - cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into kv. - max_seqlen_q: int. Maximum query sequence length in the batch. - max_seqlen_k: int. Maximum key sequence length in the batch. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of - (-alibi_slope * |i + seqlen_k - seqlen_q - j|) - is added to the attention score of query i and key j. - deterministic: bool. Whether to use the deterministic implementation of the backward pass, - which is slightly slower and uses more memory. The forward pass is always deterministic. - return_attn_probs: bool. Whether to return the attention probabilities. This option is for - testing only. The returned probabilities are not guaranteed to be correct - (they might not have the right scaling). - Return: - out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). - """ - return FlashAttnVarlenFunc.apply( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_attn_probs, - block_table, - torch.is_grad_enabled(), - fused_backward, - ) - - -class FlashAttnVarlenFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_softmax, - block_table, - is_grad_enabled, - fused_backward, - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # cast input to fp8 - fp8_dtype = torch.float8_e4m3fnuz - q_fp8, descale_q = cast_varlen_to_fp8(q, fp8_dtype, cu_seqlens=cu_seqlens_q) - k_fp8, descale_k = cast_varlen_to_fp8(k, fp8_dtype, cu_seqlens=cu_seqlens_k) - v_fp8, descale_v = cast_varlen_to_fp8(v, fp8_dtype, cu_seqlens=cu_seqlens_k) - - out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( - q_fp8, - k_fp8, - v_fp8, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - alibi_slopes=alibi_slopes, - return_lse=return_lse, - return_softmax=return_softmax and dropout_p > 0, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - fused_backward=fused_backward, - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, descale_q, descale_k, descale_v) - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - ctx.fused_backward = fused_backward - out = out_padded[..., :head_size_og] - result = [out] - if return_lse: - result.append(softmax_lse) - if return_softmax: - result.append(S_dmask) - - return tuple(result) - - @staticmethod - def backward(ctx, do, *args): - q_fp8, k_fp8, v_fp8, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, descale_q, descale_k, descale_v = ctx.saved_tensors - dq, dk, dv = torch.zeros_like(q_fp8, dtype=torch.float32), torch.zeros_like(k_fp8, dtype=torch.float32), torch.zeros_like(v_fp8, dtype=torch.float32) - head_size_v_og = do.size(3) - do_padded = do - if head_size_v_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) - - fp8_dtype = torch.float8_e4m3fnuz - do_padded_fp8, descale_do = cast_varlen_to_fp8(do_padded, fp8_dtype, "thd", cu_seqlens_q) - - _flash_attn_backward( - do_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out, - softmax_lse, - dq, - dk, - dv, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_k=ctx.max_seqlen_k, - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_do=descale_do - ) - dq = dq[..., : q_fp8.shape[-1]] # We could have padded the head dimension - dk = dk[..., : k_fp8.shape[-1]] - dv = dv[..., : v_fp8.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None - -def flash_attn_varlen_fp8_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_lse=False, - return_attn_probs=False, - block_table=None, - fused_backward=False, -): - return FlashAttnVarlenFP8Func.apply( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_attn_probs, - block_table, - torch.is_grad_enabled(), - fused_backward, - ) + return delta \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py similarity index 89% rename from flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py rename to flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py index 67f7498f083..8bdcfd10d6a 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py @@ -1,10 +1,12 @@ +import os import torch import triton # type: ignore import triton.language as tl # type: ignore from typing import Literal, Optional -from .utils import DEBUG, AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, get_shapes_from_layout, compute_fp8_scaling_factors, \ - get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_cdna, is_fp8, is_rdna +from .utils import AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, \ + create_dropout_mask, create_dropout_mask_varlen, is_cdna, is_fp8, is_rdna, round_multiple +DEBUG= False # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) @@ -109,7 +111,7 @@ def _bwd_preprocess( O, DO, # noqa: E741 Delta, stride_ob, stride_oh, stride_om, stride_od, - stride_deltab, stride_deltah, stride_deltam, + stride_delta_b, stride_delta_h, stride_delta_m, stride_descale_do_z, cu_seqlens_q, max_seqlen_q, Descale_do, @@ -160,8 +162,8 @@ def _bwd_preprocess( delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) else: delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) - delta_offset = Delta + bid * stride_deltab + hid * stride_deltah + q_start * stride_deltam - tl.store(delta_offset + offs_m * stride_deltam, delta, mask=mask_m) + delta_offset = Delta + bid * stride_delta_b + hid * stride_delta_h + q_start * stride_delta_m + tl.store(delta_offset + offs_m * stride_delta_m, delta, mask=mask_m) # The main inner-loop logic for computing dK and dV. @@ -172,7 +174,7 @@ def _bwd_dkdv_inner( stride_qm, stride_qk, stride_dom, stride_dok, stride_dropoutm, stride_dropoutn, - stride_deltam, + stride_lse_m, stride_delta_m, BLOCK_M: tl.constexpr, # 16 BLOCK_N: tl.constexpr, # 128 HEAD_DIM: tl.constexpr, # @@ -243,7 +245,7 @@ def _bwd_dkdv_inner( dropout_mask = rand_vals > dropout_p dropout_scale = 1.0 / (1 - dropout_p) # Load m before computing qk to reduce pipeline stall. - m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + m = tl.load(M + offs_m * stride_lse_m, mask=mask_m, other=0.0) if IS_FP8: qkT = (tl.dot(k, qT) * descale_q * descale_k) else: @@ -297,7 +299,7 @@ def _bwd_dkdv_inner( if start_n == 256: print(f"pT: {pT.shape}\n", pT) # D (= delta) is pre-divided by ds_scale. - Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) + Di = tl.load(D + offs_m * stride_delta_m, mask=mask_m) # Compute dP and dS. if IS_FP8: dpT = (tl.dot(v, tl.trans(do)) * descale_v * descale_do) @@ -326,7 +328,8 @@ def _bwd_dq_inner( # shared by Q/K/V. stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, stride_dropoutm, stride_dropoutn, # stride for dropout - stride_deltam, + stride_lse_m, + stride_delta_m, seqlen_q, seqlen_k, # BLOCK_M2: tl.constexpr, # BLOCK_N2: tl.constexpr, # @@ -359,7 +362,7 @@ def _bwd_dq_inner( kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk # D (= delta) is pre-divided by ds_scale. - Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) + Di = tl.load(Delta + offs_m * stride_delta_m, mask=mask_m, other=0.0) # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) curr_n = start_n @@ -458,7 +461,8 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_dqb, stride_dqh, stride_dqm, stride_dqd, stride_dkb, stride_dkh, stride_dkn, stride_dkd, stride_dvb, stride_dvh, stride_dvn, stride_dvd, - stride_deltab, stride_deltah, stride_deltam, + stride_lse_b, stride_lse_h, stride_lse_m, + stride_delta_b, stride_delta_h, stride_delta_m, stride_dob, stride_doh, stride_dom, stride_dod, stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, @@ -568,10 +572,10 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba Q_ptr = Q + adj_q adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom DO_ptr = DO + adj_do - adj_delta = bid * stride_deltab + hqid * stride_deltah + \ - q_start * stride_deltam - M_ptr = M + adj_delta + adj_delta = bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m if USE_ALIBI: alibi_offset = bid * stride_az + hqid * stride_ah @@ -614,7 +618,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_qm, stride_qd, # strides for q stride_dom, stride_dod, # strides for o stride_dropoutm, stride_dropoutn, # strides for dropout - stride_deltam, + stride_lse_m, stride_delta_m, MASK_BLOCK_M1, BLOCK_N1, # block dim HEAD_DIM, ACTUAL_HEAD_DIM, # head dim dropout_p, philox_seed, batch_philox_offset, dropout_offset, @@ -644,7 +648,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_qm, stride_qd, # strides for q stride_dom, stride_dod, # strides for o stride_dropoutm, stride_dropoutn, # strides for dropout - stride_deltam, + stride_lse_m, stride_delta_m, BLOCK_M1, BLOCK_N1, # block dim HEAD_DIM, ACTUAL_HEAD_DIM, # head dim dropout_p, philox_seed, batch_philox_offset, dropout_offset, @@ -705,8 +709,10 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom adj_delta = \ - bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m if USE_ALIBI: alibi_offset = bid * stride_az + hqid * stride_ah @@ -726,7 +732,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) - m = tl.load(M + adj_delta + offs_m * stride_deltam, + m = tl.load(M + adj_m + offs_m * stride_lse_m, mask=offs_m < seqlen_q) m = m[:, None] @@ -749,7 +755,8 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba q, K, V, do, m, Delta_ptr, sm_scale, stride_qm, stride_qd, stride_kn, stride_kd, stride_vn, stride_vd, stride_dropoutm, stride_dropoutn, - stride_deltam, + stride_lse_m, + stride_delta_m, seqlen_q, seqlen_k, BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, ACTUAL_HEAD_DIM, @@ -775,7 +782,8 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba q, K, V, do, m, Delta_ptr, sm_scale, stride_qm, stride_qd, stride_kn, stride_kd, stride_vn, stride_vd, stride_dropoutm, stride_dropoutn, - stride_deltam, + stride_lse_m, + stride_delta_m, seqlen_q, seqlen_k, BLOCK_M2, BLOCK_N2, HEAD_DIM, ACTUAL_HEAD_DIM, @@ -814,7 +822,8 @@ def bwd_kernel_noncausal( stride_dqb, stride_dqh, stride_dqm, stride_dqd, stride_dkb, stride_dkh, stride_dkn, stride_dkd, stride_dvb, stride_dvh, stride_dvn, stride_dvd, - stride_deltab, stride_deltah, stride_deltam, + stride_lse_b, stride_lse_h, stride_lse_m, + stride_delta_b, stride_delta_h, stride_delta_m, stride_dob, stride_doh, stride_dom, stride_dod, stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, @@ -890,9 +899,10 @@ def bwd_kernel_noncausal( Q_ptr = Q + adj_q adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom DO_ptr = DO + adj_do - adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - M_ptr = M + adj_delta + adj_delta = bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m if USE_ALIBI: alibi_offset = bid * stride_az + hqid * stride_ah @@ -927,7 +937,8 @@ def bwd_kernel_noncausal( stride_qm, stride_qd, # strides for q stride_dom, stride_dod, # strides for o stride_dropoutm, stride_dropoutn, # strides for dropout - stride_deltam, + stride_lse_m, + stride_delta_m, BLOCK_M1, BLOCK_N1, # block dim HEAD_DIM, ACTUAL_HEAD_DIM, # head dim dropout_p, philox_seed, batch_philox_offset, dropout_offset, # @@ -974,8 +985,10 @@ def bwd_kernel_noncausal( adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom adj_delta = \ - bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m if USE_ALIBI: alibi_offset = bid * stride_az + hqid * stride_ah @@ -996,7 +1009,7 @@ def bwd_kernel_noncausal( q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) - m = tl.load(M + adj_delta + offs_m * stride_deltam, + m = tl.load(M + adj_m + offs_m * stride_lse_m, mask=offs_m < seqlen_q) m = m[:, None] @@ -1019,7 +1032,8 @@ def bwd_kernel_noncausal( q, K, V, do, m, Delta_ptr, sm_scale, stride_qm, stride_qd, stride_kn, stride_kd, stride_vn, stride_vd, stride_dropoutm, stride_dropoutn, - stride_deltam, + stride_lse_m, + stride_delta_m, seqlen_q, seqlen_k, BLOCK_M2, BLOCK_N2, HEAD_DIM, ACTUAL_HEAD_DIM, @@ -1048,8 +1062,11 @@ def is_contiguous(x, name): else: print(f"{name} is not contiguous") return x.contiguous() + + +OLD_LSE = os.environ.get('OLD_LSE', '0').lower() in ('1', 'true', 'yes') -def attention_prefill_backward_triton_split_oneKernel_impl( +def attention_prefill_backward_triton_split_fused_no_atomics_impl( do: torch.Tensor, q: torch.Tensor, k: torch.Tensor, @@ -1120,27 +1137,44 @@ def attention_prefill_backward_triton_split_oneKernel_impl( stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = stride_descale_do_z = None - # get strides and shape - batch, nheads_q, nheads_k, head_size, max_seqlen_q_final, max_seqlen_k_final = \ - get_shapes_from_layout( - q, k, layout, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k - ) - q_strides, k_strides, v_strides, o_strides = \ - get_strides_from_layout(q, k, v, o, layout) - stride_qb, stride_qh, stride_qm, stride_qd = q_strides - stride_kb, stride_kh, stride_kn, stride_kd = k_strides - stride_vb, stride_vh, stride_vn, stride_vd = v_strides - stride_ob, stride_oh, stride_om, stride_od = o_strides - dq_strides, dk_strides, dv_strides, do_strides = \ - get_strides_from_layout(dq, dk, dv, do, layout) - stride_dqb, stride_dqh, stride_dqm, stride_dqd = dq_strides - stride_dkb, stride_dkh, stride_dkn, stride_dkd = dk_strides - stride_dvb, stride_dvh, stride_dvn, stride_dvd = dv_strides - stride_dob, stride_doh, stride_dom, stride_dod = do_strides + # get params, strides and shape IS_VARLEN = layout == "thd" use_dropout = (dropout_p > 0.0) + + # get shapes and strides + if IS_VARLEN: + # shape + _, nheads_q, head_size = q.shape + _, nheads_k, _ = k.shape + batch = len(cu_seqlens_q) - 1 + max_seqlen_q_final = max_seqlen_q + max_seqlen_k_final = max_seqlen_k + + # strides + stride_qb, stride_qh, stride_qm, stride_qd = 0, q.stride(1), q.stride(0), q.stride(2) + stride_kb, stride_kh, stride_kn, stride_kd = 0, k.stride(1), k.stride(0), k.stride(2) + stride_vb, stride_vh, stride_vn, stride_vd = 0, v.stride(1), v.stride(0), v.stride(2) + stride_ob, stride_oh, stride_om, stride_od = 0, o.stride(1), o.stride(0), o.stride(2) + stride_dqb, stride_dqh, stride_dqm, stride_dqd = 0, dq.stride(1), dq.stride(0), dq.stride(2) + stride_dkb, stride_dkh, stride_dkn, stride_dkd = 0, dk.stride(1), dk.stride(0), dk.stride(2) + stride_dvb, stride_dvh, stride_dvn, stride_dvd = 0, dv.stride(1), dv.stride(0), dv.stride(2) + stride_dob, stride_doh, stride_dom, stride_dod = 0, do.stride(1), do.stride(0), do.stride(2) + stride_lse_b, stride_lse_h, stride_lse_m = (0, softmax_lse.stride(0), softmax_lse.stride(1)) + else: + # shapes + batch, max_seqlen_q_final, nheads_q, head_size = q.shape + _, max_seqlen_k_final, nheads_k, _ = k.shape + + # strides + stride_qb, stride_qh, stride_qm, stride_qd = q.stride(0), q.stride(2), q.stride(1), q.stride(3) + stride_kb, stride_kh, stride_kn, stride_kd = k.stride(0), k.stride(2), k.stride(1), k.stride(3) + stride_vb, stride_vh, stride_vn, stride_vd = v.stride(0), v.stride(2), v.stride(1), v.stride(3) + stride_ob, stride_oh, stride_om, stride_od = o.stride(0), o.stride(2), o.stride(1), o.stride(3) + stride_dqb, stride_dqh, stride_dqm, stride_dqd = dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3) + stride_dkb, stride_dkh, stride_dkn, stride_dkd = dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3) + stride_dvb, stride_dvh, stride_dvn, stride_dvd = dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3) + stride_dob, stride_doh, stride_dom, stride_dod = do.stride(0), do.stride(2), do.stride(1), do.stride(3) + stride_lse_b, stride_lse_h, stride_lse_m = softmax_lse.stride() use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) # get closest power of 2 over or equal to 32. @@ -1150,18 +1184,35 @@ def attention_prefill_backward_triton_split_oneKernel_impl( ACTUAL_HEAD_DIM = head_size # init delta - delta = torch.empty_like(softmax_lse) - if IS_VARLEN: - stride_deltab = 0 - stride_deltam, stride_deltah = delta.stride() + if OLD_LSE: + delta = torch.empty_like(softmax_lse) + if IS_VARLEN: + stride_delta_b, stride_delta_h, stride_delta_m = 0, delta.stride(0), delta.stride(1) + else: + stride_delta_b, stride_delta_h, stride_delta_m = delta.stride() else: - stride_deltab, stride_deltah, stride_deltam = delta.stride() + if IS_VARLEN: + # interface expects the varlen sequence dims to rounded like this. Not sure why. + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + total_q_rounded = total_q + 128 * batch_size + delta_padded = torch.zeros((nheads_q, total_q_rounded), device=q.device, dtype=torch.float32) + delta = delta_padded[:, :total_q] + stride_delta_b, stride_delta_h, stride_delta_m = 0, delta.stride(0), delta.stride(1) + else: + # the interface expects the sequence dimension to be rounded to 128 + max_seqlen_q_rounded = round_multiple(max_seqlen_q_final, 128) + delta_padded = torch.zeros((batch, nheads_q, max_seqlen_q_rounded), + device=softmax_lse.device, dtype=torch.float32) + delta = delta_padded[:, :, :max_seqlen_q_final] + stride_delta_b, stride_delta_h, stride_delta_m = delta.stride() + pre_grid = lambda META: (triton.cdiv(max_seqlen_q_final, META['PRE_BLOCK']), batch, nheads_q) _bwd_preprocess[pre_grid]( o, do, delta, stride_ob, stride_oh, stride_om, stride_od, - stride_deltab, stride_deltah, stride_deltam, + stride_delta_b, stride_delta_h, stride_delta_m, stride_descale_do_z, cu_seqlens_q, max_seqlen_q_final, descale_do, @@ -1214,7 +1265,8 @@ def attention_prefill_backward_triton_split_oneKernel_impl( stride_dqb, stride_dqh, stride_dqm, stride_dqd, stride_dkb, stride_dkh, stride_dkn, stride_dkd, stride_dvb, stride_dvh, stride_dvn, stride_dvd, - stride_deltab, stride_deltah, stride_deltam, + stride_lse_b, stride_lse_h, stride_lse_m, + stride_delta_b, stride_delta_h, stride_delta_m, stride_dob, stride_doh, stride_dom, stride_dod, stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, @@ -1247,7 +1299,8 @@ def attention_prefill_backward_triton_split_oneKernel_impl( stride_dqb, stride_dqh, stride_dqm, stride_dqd, stride_dkb, stride_dkh, stride_dkn, stride_dkd, stride_dvb, stride_dvh, stride_dvn, stride_dvd, - stride_deltab, stride_deltah, stride_deltam, + stride_lse_b, stride_lse_h, stride_lse_m, + stride_delta_b, stride_delta_h, stride_delta_m, stride_dob, stride_doh, stride_dom, stride_dod, stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, @@ -1271,4 +1324,7 @@ def attention_prefill_backward_triton_split_oneKernel_impl( DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) - return delta \ No newline at end of file + if OLD_LSE: + return delta + else: + return delta_padded diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py index 5cc93edc5e4..2728dca7349 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py @@ -2,9 +2,11 @@ import triton # type: ignore import triton.language as tl # type: ignore from typing import Literal, Optional -from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, get_shapes_from_layout, \ +from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, get_shapes_from_layout, \ get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_fp8 +DEBUG = False + # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) diff --git a/flash_attn/flash_attn_triton_amd/bwd_ref.py b/flash_attn/flash_attn_triton_amd/bwd_ref.py index 639211a51f6..8bdccb1d329 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/bwd_ref.py @@ -1,8 +1,9 @@ import torch import math from typing import Literal, Optional -from .utils import DEBUG, compute_alibi_tensor_ref +from .utils import compute_alibi_tensor_ref +DEBUG = False DEBUG_CORE = False def attention_backward_core_ref_impl( @@ -196,8 +197,8 @@ def attention_varlen_backward_pytorch_ref_impl( dq = torch.zeros_like(q) dk = torch.zeros_like(k) dv = torch.zeros_like(v) - # delta has the same shape as softmax_lse: [total_L_q, nheads_q] - delta = torch.zeros((total_L_q, nheads_q), dtype=torch.float32, device=o.device) + # delta has the same shape as softmax_lse + delta = torch.zeros_like(softmax_lse) for i in range(batch_size): # Get the start and end indices for the current sequence @@ -212,7 +213,7 @@ def attention_varlen_backward_pytorch_ref_impl( v_i = v[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] do_i = do[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] o_i = o[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] - softmax_lse_i = softmax_lse[start_q:end_q, :] # [L_q_i, nheads_q] + softmax_lse_i = softmax_lse[:, start_q:end_q] # [nheads_q, L_q_i] if group_size != 1: # MQA or GQA case @@ -220,7 +221,7 @@ def attention_varlen_backward_pytorch_ref_impl( q_i = q_i.view(q_i.shape[0], nheads_k, group_size, head_dim) do_i = do_i.view(do_i.shape[0], nheads_k, group_size, head_dim) o_i = o_i.view(o_i.shape[0], nheads_k, group_size, head_dim) - softmax_lse_i = softmax_lse_i.view(softmax_lse_i.shape[0], nheads_k, group_size) + softmax_lse_i = softmax_lse_i.view(nheads_k, group_size, softmax_lse_i.shape[1]) # Expand k_i and v_i to match group_size k_i = k_i.unsqueeze(2).expand(-1, -1, group_size, -1) v_i = v_i.unsqueeze(2).expand(-1, -1, group_size, -1) @@ -228,16 +229,17 @@ def attention_varlen_backward_pytorch_ref_impl( q_i = q_i.reshape(q_i.shape[0], nheads_k * group_size, head_dim) do_i = do_i.reshape(do_i.shape[0], nheads_k * group_size, head_dim) o_i = o_i.reshape(o_i.shape[0], nheads_k * group_size, head_dim) - softmax_lse_i = softmax_lse_i.reshape(softmax_lse_i.shape[0], nheads_k * group_size) + softmax_lse_i = softmax_lse_i.reshape(nheads_k * group_size, softmax_lse_i.shape[2]) k_i = k_i.reshape(k_i.shape[0], nheads_k * group_size, head_dim) v_i = v_i.reshape(v_i.shape[0], nheads_k * group_size, head_dim) + # Permute to [nheads_total, L, head_dim] q_i = q_i.permute(1, 0, 2) k_i = k_i.permute(1, 0, 2) v_i = v_i.permute(1, 0, 2) do_i = do_i.permute(1, 0, 2) o_i = o_i.permute(1, 0, 2) - softmax_lse_i = softmax_lse_i.transpose(0, 1) + if alibi_slopes is not None: alibi_slopes_i = alibi_slopes[i] else: @@ -264,7 +266,6 @@ def attention_varlen_backward_pytorch_ref_impl( dq_i = dq_i.permute(1, 0, 2) # [L_q_i, nheads_total, head_dim] dk_i = dk_i.permute(1, 0, 2) # [L_k_i, nheads_total, head_dim] dv_i = dv_i.permute(1, 0, 2) # [L_k_i, nheads_total, head_dim] - delta_i = delta_i.transpose(1, 0) # [L_q_i, nheads_total] if group_size != 1: # Reshape dq_i and delta_i back to original shape @@ -286,7 +287,7 @@ def attention_varlen_backward_pytorch_ref_impl( dq[start_q:end_q, :, :] = dq_i dk[start_k:end_k, :, :] += dk_i # Accumulate gradients for shared keys dv[start_k:end_k, :, :] += dv_i # Accumulate gradients for shared values - delta[start_q:end_q, :] = delta_i + delta[:, start_q:end_q] = delta_i return dq, dk, dv, delta diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index 3f2d92c22d6..e165d714876 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -2,7 +2,9 @@ import triton import triton.language as tl from typing import Literal, Optional, Union -from .utils import AUTOTUNE, DEBUG, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna +from .utils import AUTOTUNE, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna + +DEBUG = False def get_cdna_autotune_configs(): return [ diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 4d9a7bf17b1..e33982bb6a7 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -2,7 +2,9 @@ import triton import triton.language as tl from typing import Literal, Optional, Union -from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_arch, get_shapes_from_layout, get_strides_from_layout, is_cdna, is_fp8, is_rdna, create_dropout_mask +from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_arch, is_cdna, is_fp8, is_rdna, create_dropout_mask + +DEBUG = False # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) @@ -610,7 +612,7 @@ def attention_prefill_forward_triton_impl( stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = None # check flags - is_varlen = layout == "thd" + IS_VARLEN = layout == "thd" use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) is_inference = False if cache_seqlens is None else True if is_inference: @@ -622,8 +624,36 @@ def attention_prefill_forward_triton_impl( if (bias is not None): assert (bias.numel() < 2**31) - batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shapes_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k) - q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout) + # get shape and strides + if IS_VARLEN: # thd layout + # shape + total_q, nheads_q, head_size = q.shape + _, nheads_k, _ = k.shape + batch = len(cu_seqlens_q) - 1 + + # softmax_lse is the log of the normalization constant / sum of expoential score(unnormalzied probablities) + softmax_lse = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) + + # strides + stride_qb, stride_qh, stride_qm, stride_qd = 0, q.stride(1), q.stride(0), q.stride(2) + stride_kb, stride_kh, stride_kn, stride_kd = 0, k.stride(1), k.stride(0), k.stride(2) + stride_vb, stride_vh, stride_vn, stride_vd = 0, v.stride(1), v.stride(0), v.stride(2) + stride_ob, stride_oh, stride_om, stride_od = 0, o.stride(1), o.stride(0), o.stride(2) + stride_lse_z, stride_lse_h, stride_lse_m = 0, softmax_lse.stride(0), softmax_lse.stride(1) + else: # bshd layout + # shape + batch, seqlen_q, nheads_q, head_size = q.shape + _, _, nheads_k, _ = k.shape + + # softmax_lse is the log of the normalization constant / sum of expoential score(unnormalzied probablities) + softmax_lse = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32) + + # strides + stride_qb, stride_qh, stride_qm, stride_qd = q.stride(0), q.stride(2), q.stride(1), q.stride(3) + stride_kb, stride_kh, stride_kn, stride_kd = k.stride(0), k.stride(2), k.stride(1), k.stride(3) + stride_vb, stride_vh, stride_vn, stride_vd = v.stride(0), v.stride(2), v.stride(1), v.stride(3) + stride_ob, stride_oh, stride_om, stride_od = o.stride(0), o.stride(2), o.stride(1), o.stride(3) + stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() # Get closest power of 2 over or equal to 32. padded_d_model = 1 << (head_size - 1).bit_length() @@ -650,35 +680,34 @@ def attention_prefill_forward_triton_impl( else: dropout_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, dtype=torch.float32) - scores_strides = (sd_mask.stride(0), sd_mask.stride(1), sd_mask.stride(2), sd_mask.stride(3)) + stride_sz, stride_sh, stride_sm, stride_sn = (sd_mask.stride(0), sd_mask.stride(1), sd_mask.stride(2), sd_mask.stride(3)) else: sd_mask = None dropout_mask = None - scores_strides = (0, 0, 0, 0) - - # stores LSE the log of the normalization constant / sum of expoential score(unnormalzied probablities) - if is_varlen: - total_seqlen_q, _, _ = q.shape - softmax_lse = torch.zeros((nheads_q, total_seqlen_q), device=q.device, dtype=torch.float32) - stride_lse_h, stride_lse_m = softmax_lse.stride() - stride_lse_z = 0 - else: - softmax_lse = torch.zeros((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) - stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() + stride_sz, stride_sh, stride_sm, stride_sn = (0, 0, 0, 0) + if bias is not None: - bias_strides = (bias.stride(0), bias.stride(1),bias.stride(2), + stride_bz, stride_bh, stride_bm, stride_bn = (bias.stride(0), bias.stride(1),bias.stride(2), bias.stride(3)) else: - bias_strides = (0, 0, 0, 0) + stride_bz, stride_bh, stride_bm, stride_bn = (0, 0, 0, 0) attn_fwd[grid](q, k, v, bias, cache_seqlens, cache_batch_idx, descale_q, descale_k, descale_v, descale_o, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, - sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, - *bias_strides, stride_az, stride_ah, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, + sm_scale, softmax_lse, o, + stride_qb, stride_qh, stride_qm, stride_qd, + stride_kb, stride_kh, stride_kn, stride_kd, + stride_vb, stride_vh, stride_vn, stride_vd, + stride_ob, stride_oh, stride_om, stride_od, + stride_bz, stride_bh, stride_bm, stride_bn, + stride_az, stride_ah, + stride_sz, stride_sh, stride_sm, stride_sn, + stride_lse_z, stride_lse_h, stride_lse_m, + cu_seqlens_q, cu_seqlens_k, dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, IS_VARLEN=is_varlen, IS_INFERENCE=is_inference, + MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, IS_VARLEN=IS_VARLEN, IS_INFERENCE=is_inference, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT, FLIP_GRID=FLIP_GRID) diff --git a/flash_attn/flash_attn_triton_amd/fwd_ref.py b/flash_attn/flash_attn_triton_amd/fwd_ref.py index baefb2410c1..6af99798ae9 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/fwd_ref.py @@ -1,8 +1,9 @@ import torch import math from typing import Literal, Optional -from .utils import DEBUG, compute_alibi_tensor_ref +from .utils import compute_alibi_tensor_ref +DEBUG = False DEBUG_CORE = False def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2): @@ -247,7 +248,7 @@ def attention_varlen_forward_pytorch_ref_impl( total_L_k = k.shape[0] o = torch.zeros((total_L_q, nheads_q, head_dim), dtype=q.dtype, device=q.device) - softmax_lse = torch.zeros((total_L_q, nheads_q), dtype=torch.float32, device=q.device) + softmax_lse = torch.zeros((nheads_q, total_L_q), dtype=torch.float32, device=q.device) sd_mask = torch.zeros((batch_size, nheads_q, max_seqlen_q, max_seqlen_k), dtype=torch.float32, device=q.device) # Compute group_size for MQA/GQA handling @@ -318,12 +319,11 @@ def attention_varlen_forward_pytorch_ref_impl( # Convert back to 'thd' layout o_i = o_i.permute(1, 0, 2) # [L_q_i, nheads_q, head_dim] - softmax_lse_i = softmax_lse_i.permute(1, 0) # [L_q_i, nheads_q] sd_mask_i = sd_mask_i # [nheads_q, L_q_i, L_k_i] # Place outputs in pre-allocated tensors o[start_q:end_q, :, :] = o_i - softmax_lse[start_q:end_q, :] = softmax_lse_i + softmax_lse[:, start_q:end_q] = softmax_lse_i sd_mask[i, :, :seqlen_q, :seqlen_k] = sd_mask_i return o, softmax_lse, sd_mask diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index cbe597c5dbe..9c10b7436c2 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -1,21 +1,20 @@ import torch import os from .fwd_prefill import attention_prefill_forward_triton_impl -from .bwd_prefill import attention_prefill_backward_triton_impl from .bwd_prefill_split import attention_prefill_backward_triton_split_impl -from .bwd_prefill_fused import _flash_attn_backward as attention_prefill_backward_triton_fused_impl -from .bwd_prefill_onekernel import attention_prefill_backward_triton_split_oneKernel_impl +from .bwd_prefill_fused_atomics import attention_prefill_backward_triton_fused_atomics_impl +from .bwd_prefill_fused_no_atomics import attention_prefill_backward_triton_split_fused_no_atomics_impl from .fwd_decode import attention_decode_forward_triton_impl from .fwd_ref import attention_forward_pytorch_ref_impl from .bwd_ref import attention_backward_pytorch_ref_impl -from .utils import DEBUG, USE_REF, MetaData, get_shapes_from_layout, is_fp8 +from .utils import DEBUG, USE_REF, MetaData, is_fp8 from einops import rearrange, repeat from flash_attn.layers.rotary import apply_rotary_emb from typing import Literal, Optional, Union USE_EXP2 = True -BWD_MODE = os.environ.get('BWD_MODE', 'jingning').lower() +BWD_MODE = os.environ.get('BWD_MODE', 'fused_no_atomics').lower() def fwd(q: torch.Tensor, k: torch.Tensor, @@ -67,8 +66,6 @@ def fwd(q: torch.Tensor, metadata.max_seqlens_q = q.shape[1] metadata.max_seqlens_k = k.shape[1] metadata.layout = "bshd" - if return_softmax: - metadata.return_scores = True # get shape batch, _ , nheads_q, _= q.shape @@ -137,7 +134,7 @@ def fwd(q: torch.Tensor, metadata.dropout_p, metadata.philox_seed, metadata.philox_offset, - metadata.return_scores, + metadata.return_softmax, USE_EXP2, descale_q, descale_k, @@ -153,6 +150,7 @@ def fwd(q: torch.Tensor, print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) print("softmax_lse:", softmax_lse, softmax_lse.shape) print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None ) + print("rng_state:", rng_state) return out, softmax_lse, sd_mask, rng_state @@ -302,8 +300,8 @@ def bwd( descale_dv, ) delta = delta_triton - elif BWD_MODE == "fused": - delta_triton = attention_prefill_backward_triton_fused_impl( + elif BWD_MODE == "fused_atomics": + delta_triton = attention_prefill_backward_triton_fused_atomics_impl( dout, q, k, @@ -330,8 +328,8 @@ def bwd( True, ) delta = delta_triton - elif BWD_MODE == "jingning": - delta_triton = attention_prefill_backward_triton_split_oneKernel_impl( + elif BWD_MODE == "fused_no_atomics": + delta_triton = attention_prefill_backward_triton_split_fused_no_atomics_impl( dout, q, k, @@ -436,8 +434,6 @@ def varlen_fwd( # Setup metadata metadata = MetaData(sm_scale=softmax_scale) - if return_softmax: - metadata.return_scores = True metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # set layout to "thd" and other metdata assert metadata.layout is not None @@ -509,7 +505,7 @@ def varlen_fwd( metadata.dropout_p, metadata.philox_seed, metadata.philox_offset, - metadata.return_scores, + metadata.return_softmax, USE_EXP2, descale_q, descale_k, @@ -677,8 +673,8 @@ def varlen_bwd( descale_dv, ) delta = delta_triton - elif BWD_MODE == "fused": - delta_triton = attention_prefill_backward_triton_fused_impl( + elif BWD_MODE == "fused_atomics": + delta_triton = attention_prefill_backward_triton_fused_atomics_impl( dout, q, k, @@ -705,8 +701,8 @@ def varlen_bwd( True, ) delta = delta_triton - elif BWD_MODE == "jingning": - delta_triton = attention_prefill_backward_triton_split_oneKernel_impl( + elif BWD_MODE == "fused_no_atomics": + delta_triton = attention_prefill_backward_triton_split_fused_no_atomics_impl( dout, q, k, @@ -900,7 +896,7 @@ def fwd_kvcache( metadata.dropout_p, metadata.philox_seed, metadata.philox_offset, - metadata.return_scores, + metadata.return_softmax, USE_EXP2, None, None, diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index ea82de065b5..f634103ca69 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -17,15 +17,18 @@ flash_attn_varlen_fp8_func, flash_attn_varlen_kvpacked_func, flash_attn_varlen_qkvpacked_func, - flash_attn_varlen_qkvpacked_fp8_func + flash_attn_varlen_qkvpacked_fp8_func, + flash_attn_with_kvcache ) -from .utils import DEBUG, input_helper, arch_supports_fp8 +from .utils import generate_bshd_kv_packed, generate_bshd_qkv_packed, generate_bshd_tensor, generate_varlen_kv_packed, generate_varlen_qkv_packed, input_helper, arch_supports_fp8, generate_varlen_tensor from .fwd_ref import attention_forward_pytorch_ref_impl from .fwd_prefill import attention_prefill_forward_triton_impl -from .bwd_prefill_onekernel import attention_prefill_backward_triton_split_oneKernel_impl +from .bwd_prefill_fused_no_atomics import attention_prefill_backward_triton_split_fused_no_atomics_impl from .bwd_ref import attention_backward_pytorch_ref_impl +DEBUG = False + # set print options # torch.set_printoptions(linewidth=5e5, edgeitems=10, sci_mode=False) # np.set_printoptions(linewidth=5000, threshold=1e4, suppress=True, precision=4) @@ -101,7 +104,7 @@ def test_op_prefill_fwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dr metadata.need_causal(True) # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - metadata.need_dropout(dropout_p) + metadata.need_dropout(dropout_p, True) # call Triton's forward implementation directly @@ -128,7 +131,7 @@ def test_op_prefill_fwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dr metadata.dropout_p, metadata.philox_seed, metadata.philox_offset, - metadata.return_scores, + metadata.return_softmax, use_exp2, None, None, @@ -164,7 +167,7 @@ def test_op_prefill_fwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dr print("Compare Triton Impl with refernce Pytorch Impl") # this can be set to true manually or when using dropout - if metadata.return_scores: + if metadata.return_softmax: if DEBUG: print("sd_mask_triton:", sd_mask_triton, sd_mask_triton.shape) print("sd_mask_ref:", sd_mask_ref, sd_mask_ref.shape) @@ -268,7 +271,7 @@ def test_op_prefill_bwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dr q, k, v, do, metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, dtype, layout=layout, device=device) # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - metadata.need_dropout(dropout_p) + metadata.need_dropout(dropout_p, True) # =============================================== Reference ============================================================== # fwd @@ -334,7 +337,7 @@ def test_op_prefill_bwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dr dq_triton = torch.zeros_like(q_triton, dtype=q.dtype) # NOTE: the kernel does inplace accumlation on dq so dq has to be zeros dk_triton = torch.zeros_like(k_triton, dtype=k.dtype) if DEBUG_INPUT else torch.empty_like(k_triton, dtype=k.dtype) dv_triton = torch.zeros_like(v_triton, dtype=v.dtype) if DEBUG_INPUT else torch.empty_like(v_triton, dtype=v.dtype) - delta_triton = attention_prefill_backward_triton_split_oneKernel_impl( + delta_triton = attention_prefill_backward_triton_split_fused_no_atomics_impl( do_triton, q_triton, k_triton, @@ -931,3 +934,189 @@ def test_ir(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, for file, fp8_found in ttir_files_fp8_found_status.items(): assert fp8_found, f"{fp8_types} not found in {file}" + + +def clear_compile_cache(): + """Clear torch compile caches to prevent graph merging""" + if hasattr(torch._dynamo, 'reset'): + torch._dynamo.reset() + torch.cuda.synchronize() + + +@pytest.mark.parametrize( + "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", + [ + # (4, 8, 8, 128, 128, 64), # small test + (32, 32, 32, 531, 531, 128), # original test + # (16, 48, 16, 256, 512, 64), # MQA test (HQ > HK) + ], +) +def test_torch_compile(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD): + print(f"\n\nTesting with BATCH={BATCH}, HQ={HQ}, HK={HK}, N_CTX_Q={N_CTX_Q}, N_CTX_K={N_CTX_K}, D_HEAD={D_HEAD}") + + try: + # Test 1: flash_attn_func + print("\n1. Testing flash_attn_func...") + clear_compile_cache() + + q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD) + k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD) + v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD) + + flash_attn_func_compiled = torch.compile(flash_attn_func) + o = flash_attn_func_compiled(q, k, v, causal=True) + print(f"Output shape: {o.shape}, dtype: {o.dtype}") + o.sum().backward() + print("✓ flash_attn_func SUCCESS") + + # cleanup + del q, k, v, o + torch.cuda.empty_cache() + + + # Test 2: flash_attn_varlen_func + print("\n2. Testing flash_attn_varlen_func...") + clear_compile_cache() + + q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(BATCH * N_CTX_Q, HQ, D_HEAD, BATCH) + k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(BATCH * N_CTX_K, HK, D_HEAD, BATCH) + v, _, _ = generate_varlen_tensor(BATCH * N_CTX_K, HK, D_HEAD, BATCH) + + flash_attn_varlen_func_compiled = torch.compile(flash_attn_varlen_func) + o = flash_attn_varlen_func_compiled( + q, k, v, cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, causal=True + ) + print(f"Output shape: {o.shape}, dtype: {o.dtype}") + o.sum().backward() + print("✓ flash_attn_varlen_func SUCCESS") + + # cleanup + del q, k, v, o, cu_seqlens_q, cu_seqlens_k + torch.cuda.empty_cache() + + + # Test 3: flash_attn_qkvpacked_func + print("\n3. Testing flash_attn_qkvpacked_func...") + clear_compile_cache() + + qkv = generate_bshd_qkv_packed(BATCH, N_CTX_Q, HQ, D_HEAD) + + flash_attn_qkvpacked_func_compiled = torch.compile(flash_attn_qkvpacked_func) + o = flash_attn_qkvpacked_func_compiled(qkv, causal=True) + print(f"Output shape: {o.shape}, dtype: {o.dtype}") + o.sum().backward() + print("✓ flash_attn_qkvpacked_func SUCCESS") + + # cleanup + del qkv, o + torch.cuda.empty_cache() + + + # Test 4: flash_attn_varlen_qkvpacked_func + print("\n4. Testing flash_attn_varlen_qkvpacked_func...") + clear_compile_cache() + + total_q = BATCH * N_CTX_Q + qkv, cu_seqlens, max_seqlen = generate_varlen_qkv_packed(total_q, HQ, D_HEAD, BATCH) + + flash_attn_varlen_qkvpacked_func_compiled = torch.compile(flash_attn_varlen_qkvpacked_func) + o = flash_attn_varlen_qkvpacked_func_compiled( + qkv, cu_seqlens, max_seqlen, causal=True + ) + print(f"Output shape: {o.shape}, dtype: {o.dtype}") + o.sum().backward() + print("✓ flash_attn_varlen_qkvpacked_func SUCCESS") + + # cleanup + del qkv, o, cu_seqlens + torch.cuda.empty_cache() + + + # Test 5: flash_attn_kvpacked_func + print("\n5. Testing flash_attn_kvpacked_func...") + clear_compile_cache() + + q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD) + kv = generate_bshd_kv_packed(BATCH, N_CTX_K, HK, D_HEAD) + + flash_attn_kvpacked_func_compiled = torch.compile(flash_attn_kvpacked_func) + o = flash_attn_kvpacked_func_compiled(q, kv, causal=True) + print(f"Output shape: {o.shape}, dtype: {o.dtype}") + o.sum().backward() + print("✓ flash_attn_kvpacked_func SUCCESS") + + # cleanup + del q, kv, o + torch.cuda.empty_cache() + + + # Test 6: flash_attn_varlen_kvpacked_func + print("\n6. Testing flash_attn_varlen_kvpacked_func...") + clear_compile_cache() + + q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(BATCH * N_CTX_Q, HQ, D_HEAD, BATCH) + kv, cu_seqlens_k, max_seqlen_k = generate_varlen_kv_packed(BATCH * N_CTX_K, HK, D_HEAD, BATCH) + + flash_attn_varlen_kvpacked_func_compiled = torch.compile(flash_attn_varlen_kvpacked_func) + o = flash_attn_varlen_kvpacked_func_compiled( + q, kv, cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, causal=True + ) + print(f"Output shape: {o.shape}, dtype: {o.dtype}") + o.sum().backward() + print("✓ flash_attn_varlen_kvpacked_func SUCCESS") + + # cleanup + del q, kv, o, cu_seqlens_q, cu_seqlens_k + torch.cuda.empty_cache() + + + # Test 7: flash_attn_with_kvcache + print("\n7. Testing flash_attn_with_kvcache...") + clear_compile_cache() + + # setup cache dimensions + CACHE_SEQLEN = 1024 # max cache size + NEW_SEQLEN = 1 # for incremental decoding, usually 1 token at a time + + # create query for new tokens + q = generate_bshd_tensor(BATCH, NEW_SEQLEN, HQ, D_HEAD, dtype=torch.float16) + + # create kv cache using generators + k_cache = generate_bshd_tensor(BATCH, CACHE_SEQLEN, HK, D_HEAD, dtype=torch.float16) + v_cache = generate_bshd_tensor(BATCH, CACHE_SEQLEN, HK, D_HEAD, dtype=torch.float16) + + # cache sequence lengths + cache_seqlens = torch.full((BATCH,), 100, dtype=torch.int32, device='cuda') + + # new k,v to append to cache (optional) + k_new = generate_bshd_tensor(BATCH, NEW_SEQLEN, HK, D_HEAD, dtype=torch.float16) + v_new = generate_bshd_tensor(BATCH, NEW_SEQLEN, HK, D_HEAD, dtype=torch.float16) + + # Note: flash_attn_with_kvcache doesn't support backward pass + flash_attn_with_kvcache_compiled = torch.compile(flash_attn_with_kvcache) + + # Test with new k,v (append to cache and do attention) + with torch.no_grad(): + o = flash_attn_with_kvcache_compiled( + q, k_cache, v_cache, + k=k_new, v=v_new, + cache_seqlens=cache_seqlens, + causal=True + ) + print(f"Output shape (with new kv): {o.shape}, dtype: {o.dtype}") + + print("✓ flash_attn_with_kvcache SUCCESS") + + print("\n\n✅ ALL TESTS PASSED! ✅") + + except Exception as e: + print(f"\n❌ ERROR: {str(e)}") + # ensure we sync even on error to get proper error message + torch.cuda.synchronize() + raise e + finally: + # final cleanup + torch.cuda.empty_cache() + clear_compile_cache() diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index b129636e31d..1795e0d1366 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -45,7 +45,7 @@ class MetaData(): cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None cache_batch_idx = None packing: Optional[bool] = None - return_scores: bool = False + return_softmax: bool = False dropout_p: float = 0.0 philox_seed: Optional[int] = None philox_offset : Optional[int]= None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing. @@ -72,7 +72,7 @@ def __repr__(self) -> str: f" cache_seqlens={self.cache_seqlens},\n" f" cache_batch_idx={self.cache_batch_idx},\n" f" dropout_p={self.dropout_p},\n" - f" return_scores={self.return_scores}\n" + f" return_softmax={self.return_softmax}\n" f")") def __init__(self, sm_scale=1.0): @@ -113,7 +113,7 @@ def need_rotary(self, sin, cos, rotary_interleaved, rotary_conjunction=False): self.rotary_interleaved = rotary_interleaved self.rotary_conjunction = rotary_conjunction - def need_dropout(self, dropout_p, return_softmax = True): + def need_dropout(self, dropout_p, return_softmax): self.dropout_p = dropout_p self.return_softmax = return_softmax self.philox_seed, self.philox_offset = 0x1BF58, 0x1D4B49 @@ -129,7 +129,7 @@ def check_args(self, q, k, v, o): assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) # TODO: Remove once bias is supported with varlen assert self.bias is None - # assert not self.return_scores + # assert not self.return_softmax else: assert q.dim() == 4 assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 @@ -165,7 +165,7 @@ def generate_varlen_tensor( batch_size: Optional[int] = None, equal_seqlens: bool = False, device: str = "cuda", - dtype: torch.dtype = torch.float32, + dtype: torch.dtype = torch.float16, DEBUG_INPUT: bool = False ): if DEBUG: @@ -225,7 +225,7 @@ def generate_varlen_tensor( x.requires_grad_() return x, cu_seqlens, max_seqlen -def generate_bshd_tensor(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype, device="cuda", DEBUG_INPUT=False): +def generate_bshd_tensor(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", DEBUG_INPUT=False): # save fp8 type is_fp8_dtype = is_dtype_fp8(dtype) if is_fp8_dtype: @@ -248,7 +248,7 @@ def generate_bshd_tensor(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype, device="cuda" x.requires_grad_() return x -def generate_bhsd_tensor(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype, device="cuda", DEBUG_INPUT=False): +def generate_bhsd_tensor(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", DEBUG_INPUT=False): # save fp8 type is_fp8_dtype = is_dtype_fp8(dtype) if is_fp8_dtype: @@ -272,6 +272,235 @@ def generate_bhsd_tensor(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype, device="cuda" x.requires_grad_() return x +def generate_bshd_qkv_packed(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", DEBUG_INPUT=False): + """Generate QKV packed tensor with shape (BATCH, SEQ_LEN, 3, NUM_HEADS, D_HEAD)""" + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # gen tensor + tensor_shape = (BATCH, SEQ_LEN, 3, NUM_HEADS, D_HEAD) + if DEBUG_INPUT: + x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, SEQ_LEN, 1, 1, 1).expand(*tensor_shape).contiguous() + else: + x = torch.randn(tensor_shape, dtype=dtype, device=device) + + if is_fp8_dtype: + # cast to fp8 - need to handle the packed dimension + raise NotImplementedError("FP8 not supported for QKV packing yet") + else: + x.requires_grad_() + return x + + +def generate_bshd_kv_packed(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", DEBUG_INPUT=False): + """Generate KV packed tensor with shape (BATCH, SEQ_LEN, 2, NUM_HEADS, D_HEAD)""" + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # gen tensor + tensor_shape = (BATCH, SEQ_LEN, 2, NUM_HEADS, D_HEAD) + if DEBUG_INPUT: + x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, SEQ_LEN, 1, 1, 1).expand(*tensor_shape).contiguous() + else: + x = torch.randn(tensor_shape, dtype=dtype, device=device) + + if is_fp8_dtype: + # cast to fp8 - need to handle the packed dimension + raise NotImplementedError("FP8 not supported for KV packing yet") + else: + x.requires_grad_() + return x + + +def generate_bhsd_qkv_packed(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", DEBUG_INPUT=False): + """Generate QKV packed tensor with shape (BATCH, 3, NUM_HEADS, SEQ_LEN, D_HEAD)""" + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # gen tensor + tensor_shape = (BATCH, 3, NUM_HEADS, SEQ_LEN, D_HEAD) + if DEBUG_INPUT: + x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, 1, 1, SEQ_LEN, 1).expand(*tensor_shape).contiguous() + else: + x = torch.randn(tensor_shape, dtype=dtype, device=device) + + if is_fp8_dtype: + # cast to fp8 - need to handle the packed dimension + raise NotImplementedError("FP8 not supported for QKV packing yet") + else: + x.requires_grad_() + return x + + +def generate_bhsd_kv_packed(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", DEBUG_INPUT=False): + """Generate KV packed tensor with shape (BATCH, 2, NUM_HEADS, SEQ_LEN, D_HEAD)""" + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # gen tensor + tensor_shape = (BATCH, 2, NUM_HEADS, SEQ_LEN, D_HEAD) + if DEBUG_INPUT: + x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, 1, 1, SEQ_LEN, 1).expand(*tensor_shape).contiguous() + else: + x = torch.randn(tensor_shape, dtype=dtype, device=device) + + if is_fp8_dtype: + # cast to fp8 - need to handle the packed dimension + raise NotImplementedError("FP8 not supported for KV packing yet") + else: + x.requires_grad_() + return x + + +def generate_varlen_qkv_packed( + total_seqlen: int, + num_heads: int, + head_size: int, + batch_size: Optional[int] = None, + equal_seqlens: bool = False, + device: str = "cuda", + dtype: torch.dtype = torch.float16, + DEBUG_INPUT: bool = False +): + """Generate varlen QKV packed tensor with shape (total_seqlen, 3, num_heads, head_size)""" + if DEBUG: + print("generate_varlen_qkv_packed") + print("total_seqlen", total_seqlen) + print("num_heads", num_heads) + print("head_size", head_size) + + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # get valid batch_size + if batch_size is None: + valid_batch_sizes = [bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen] + batch_size = random.choice(valid_batch_sizes) + + # get seqlens + if equal_seqlens: + seqlens = torch.full( + (batch_size,), + total_seqlen // batch_size, + dtype=torch.int32, + device=device + ) + seqlens[-1] += total_seqlen % batch_size + else: + seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device) + + # create cumulative sequence lengths + cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)]).to(torch.int32).to(device=device) + max_seqlen = torch.max(seqlens).to(torch.int32).item() + + # create varlen qkv packed tensor + if DEBUG_INPUT: + x = torch.zeros(total_seqlen, 3, num_heads, head_size, dtype=dtype, device=device) + for i in range(batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i+1].item() + length = end - start + + x[start:end, :, :, :] = ( + torch.arange(length, dtype=dtype, device=device) + .view(length, 1, 1, 1) + .expand(length, 3, num_heads, head_size) + ) + else: + x = torch.randn((total_seqlen, 3, num_heads, head_size), dtype=dtype, device=device) + + if is_fp8_dtype: + # cast to fp8 - need to handle the packed dimension + raise NotImplementedError("FP8 not supported for QKV packing yet") + else: + x.requires_grad_() + return x, cu_seqlens, max_seqlen + + +def generate_varlen_kv_packed( + total_seqlen: int, + num_heads: int, + head_size: int, + batch_size: Optional[int] = None, + equal_seqlens: bool = False, + device: str = "cuda", + dtype: torch.dtype = torch.float16, + DEBUG_INPUT: bool = False +): + """Generate varlen KV packed tensor with shape (total_seqlen, 2, num_heads, head_size)""" + if DEBUG: + print("generate_varlen_kv_packed") + print("total_seqlen", total_seqlen) + print("num_heads", num_heads) + print("head_size", head_size) + + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # get valid batch_size + if batch_size is None: + valid_batch_sizes = [bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen] + batch_size = random.choice(valid_batch_sizes) + + # get seqlens + if equal_seqlens: + seqlens = torch.full( + (batch_size,), + total_seqlen // batch_size, + dtype=torch.int32, + device=device + ) + seqlens[-1] += total_seqlen % batch_size + else: + seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device) + + # create cumulative sequence lengths + cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)]).to(torch.int32).to(device=device) + max_seqlen = torch.max(seqlens).to(torch.int32).item() + + # create varlen kv packed tensor + if DEBUG_INPUT: + x = torch.zeros(total_seqlen, 2, num_heads, head_size, dtype=dtype, device=device) + for i in range(batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i+1].item() + length = end - start + + x[start:end, :, :, :] = ( + torch.arange(length, dtype=dtype, device=device) + .view(length, 1, 1, 1) + .expand(length, 2, num_heads, head_size) + ) + else: + x = torch.randn((total_seqlen, 2, num_heads, head_size), dtype=dtype, device=device) + + if is_fp8_dtype: + # cast to fp8 - need to handle the packed dimension + raise NotImplementedError("FP8 not supported for KV packing yet") + else: + x.requires_grad_() + return x, cu_seqlens, max_seqlen + +# Replace the existing input_helper function in utils.py with this updated version + def input_helper( BATCH: int, HQ: int, @@ -294,20 +523,42 @@ def input_helper( # set params TOTAL_SEQLENS_Q = BATCH * N_CTX_Q TOTAL_SEQLENS_K = BATCH * N_CTX_K - equal_seqlens=False + equal_seqlens = False - # gen tensors - # TODO: the gen functions should maybe have different gen modes like random, ones, increasing seqlen - if is_fp8_dtype: - q, cu_seqlens_q, max_seqlen_q, descale_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - k, cu_seqlens_k, max_seqlen_k, descale_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - v, _, _ , descale_v = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - do, _, _ , descale_do = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) - else: - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - v, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + # deal with packing + if packing is None: + # gen tensors + if is_fp8_dtype: + q, cu_seqlens_q, max_seqlen_q, descale_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + k, cu_seqlens_k, max_seqlen_k, descale_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + v, _, _, descale_v = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + do, _, _, descale_do = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) + else: + q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + v, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + elif packing == "kv": + # gen tensors with kv packing + if is_fp8_dtype: + raise ValueError("FP8 not supported for KV packing yet") + else: + q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + kv, cu_seqlens_k, max_seqlen_k = generate_varlen_kv_packed(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + elif packing == "qkv": + # qkv packing - requires same sequence length for q and k + assert N_CTX_Q == N_CTX_K, "For QKV packing, Q and K must have same sequence length" + assert HQ == HK, "For QKV packing, Q and K must have same number of heads" + + if is_fp8_dtype: + raise ValueError("FP8 not supported for QKV packing yet") + else: + qkv, cu_seqlens_q, max_seqlen_q = generate_varlen_qkv_packed(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + cu_seqlens_k = cu_seqlens_q + max_seqlen_k = max_seqlen_q + # create dummy do for qkv case + do = torch.ones((TOTAL_SEQLENS_Q, HQ, D_HEAD), dtype=dtype, device=device) if DEBUG_INPUT else torch.randn((TOTAL_SEQLENS_Q, HQ, D_HEAD), dtype=dtype, device=device) # setup metadata if DEBUG_INPUT: @@ -317,31 +568,61 @@ def input_helper( metadata = MetaData(sm_scale=sm_scale) metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) metadata.need_causal(CAUSAL) - metadata.need_dropout(DROPOUT_P) + metadata.need_dropout(DROPOUT_P, True) + elif layout == 'bshd' or layout == "bhsd": - # gen tensors - if layout == "bshd": + # deal with packing + if packing is None: + # gen tensors + if layout == "bshd": + if is_fp8_dtype: + q, descale_q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + k, descale_k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + v, descale_v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do, descale_do = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) + else: + q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + elif layout == "bhsd": + if is_fp8_dtype: + q, descale_q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + k, descale_k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + v, descale_v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do, descale_do = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) + else: + q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + elif packing == "kv": + # gen tensors with kv packing if is_fp8_dtype: - q, descale_q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - k, descale_k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - v, descale_v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - do, descale_do = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) + raise ValueError("FP8 not supported for KV packing yet") else: - q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) - elif layout == "bhsd": + if layout == "bshd": + q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + kv = generate_bshd_kv_packed(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + elif layout == "bhsd": + q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + kv = generate_bhsd_kv_packed(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + elif packing == "qkv": + # qkv packing - requires same sequence length for q and k + assert N_CTX_Q == N_CTX_K, "For QKV packing, Q and K must have same sequence length" + assert HQ == HK, "For QKV packing, Q and K must have same number of heads" + if is_fp8_dtype: - q, descale_q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - k, descale_k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - v, descale_v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - do, descale_do = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) + raise ValueError("FP8 not supported for QKV packing yet") else: - q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + if layout == "bshd": + qkv = generate_bshd_qkv_packed(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do = torch.ones((BATCH, N_CTX_Q, HQ, D_HEAD), dtype=dtype, device=device) if DEBUG_INPUT else torch.randn((BATCH, N_CTX_Q, HQ, D_HEAD), dtype=dtype, device=device) + elif layout == "bhsd": + qkv = generate_bhsd_qkv_packed(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do = torch.ones((BATCH, HQ, N_CTX_Q, D_HEAD), dtype=dtype, device=device) if DEBUG_INPUT else torch.randn((BATCH, HQ, N_CTX_Q, D_HEAD), dtype=dtype, device=device) # setup metadata if DEBUG_INPUT: @@ -353,42 +634,22 @@ def input_helper( metadata.max_seqlens_k = N_CTX_K metadata.layout = layout metadata.need_causal(CAUSAL) - metadata.need_dropout(DROPOUT_P) + metadata.need_dropout(DROPOUT_P, True) else: raise ValueError(f"Unknown layout: {layout}") - # deal with packing + # return based on packing if packing is None: if is_fp8_dtype: return (q, descale_q), (k, descale_k), (v, descale_v), (do, descale_do), metadata else: return q, k, v, do, metadata elif packing == "kv": - # pack k and v - if layout in ["bhsd", "thd"]: - kv = torch.stack([k, v], dim=1) - elif layout == "bshd": - kv = torch.stack([k, v], dim=2) - else: - raise ValueError(f"Unknown layout: {layout}") - if is_fp8_dtype: raise ValueError("FP8 not supported kv packing yet") else: return q, kv, do, metadata elif packing == "qkv": - # qkv packing - requires same sequence length for q and k - assert N_CTX_Q == N_CTX_K, "For QKV packing, Q and K must have same sequence length" - assert HQ == HK, "For QKV packing, Q and K must have same number of heads" - - # pack q, k, and v - if layout in ["bhsd", "thd"]: - qkv = torch.stack([q, k, v], dim=1) - elif layout == "bshd": - qkv = torch.stack([q, k, v], dim=2) - else: - raise ValueError(f"Unknown layout: {layout}") - if is_fp8_dtype: raise ValueError("FP8 not supported qkv packing yet") else: @@ -693,6 +954,9 @@ def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k): relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) +def round_multiple(x, m): + return (x + m - 1) // m * m + # ------------------------------- # Dropouts # ------------------------------- From d0c3cbcb06bcc68ec27a0cfaa5adf57fc8971312 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 10 Jul 2025 14:15:29 -0400 Subject: [PATCH 05/27] Sliding Window Forward (#151) * Compress SWA work test case set up debug inputs add fwd ref one mask ref fwd first pass save ref doesnot work for bigger seqlens save new version some causal cases failing found bad cases working new attn new atten works new attn_fwd works reorg n_extra_tokens use seqlen_delta_qk ref fwd works add sliding window to bwd ref test kvcache decode ref work with everything except sliding window add debug code for 12 failing sliding window cases for decode attention_decode_forward_ref_impl mostly works except for alibi fix alibi in attention_decode_forward_ref_impl ref works with normal, varlen & kvcache move stuff around figure out masking old attn inner two inner functions remove load_fn do Lk - Lq like ref unify IS_CAUSAL code in epilogue clean up add args rm inference stuff simplify compute_masking simpler compute mask stub out returning front masking variables remove pointer pass compute ptrs inloop compute block min and max window stub inside inner mask loop trying to use attn_fwd_mask causes issues fix compiler bug when front masking gen specifc types add sliding window and debug statements use identity for v add more taste cases add comments save use k_max_token for clarity disable debug configs basic NON-CAUSAL SLIDING WINDOW non causal sliding window works on the all the shapes non sliding window working in fwd clean up fused bwd seperate old fwd_prefill move configs to utils.py * fix bwd ref bug * skip local cases so that fa output * no sliding window causal green * add backward test skip for sliding window * clean reduce in fwd_kvcache. no is_CASUAL branching * add kvcache masking * kvcache working * fix some bugs in test.py * clean up --- flash_attn/flash_attn_triton_amd/bwd_ref.py | 116 ++- .../flash_attn_triton_amd/fwd_decode.py | 120 ++- .../flash_attn_triton_amd/fwd_prefill.py | 975 ++++++++++++------ flash_attn/flash_attn_triton_amd/fwd_ref.py | 328 +++++- .../flash_attn_triton_amd/interface_fa.py | 86 +- flash_attn/flash_attn_triton_amd/test.py | 354 +------ flash_attn/flash_attn_triton_amd/utils.py | 269 +++-- tests/test_flash_attn_triton_amd.py | 58 +- 8 files changed, 1407 insertions(+), 899 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd_ref.py b/flash_attn/flash_attn_triton_amd/bwd_ref.py index 8bdccb1d329..cb1637157a0 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/bwd_ref.py @@ -7,7 +7,8 @@ DEBUG_CORE = False def attention_backward_core_ref_impl( - do, q, k, v, o, softmax_lse, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2 + do, q, k, v, o, softmax_lse, sm_scale, causal, window_size_left, window_size_right, + dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2 ): if DEBUG_CORE: print() @@ -16,10 +17,12 @@ def attention_backward_core_ref_impl( print("q:", q, q.shape) print("k:", k, k.shape) print("v:", v, v.shape) - print("o:", o, o.shape) # is a bad number + print("o:", o, o.shape) print("softmax_lse:", softmax_lse, softmax_lse.shape) print("sm_scale:", sm_scale) print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) print("dropout_p:", dropout_p) print("philox_seed:", philox_seed) print("philox_offset:", philox_offset) @@ -33,7 +36,6 @@ def attention_backward_core_ref_impl( o = o.to(torch.float32) softmax_lse = softmax_lse.to(torch.float32) - # recompute attention_scores. Make sure it matches the forward impl. i.e. It use float32 attention_scores = torch.matmul(q, k.transpose(-2, -1)) if DEBUG_CORE: @@ -50,51 +52,95 @@ def attention_backward_core_ref_impl( print("alibi_slopes:", alibi_slopes, alibi_slopes.shape) alibi_bias = compute_alibi_tensor_ref(alibi_slopes, L_q, L_k) alibi_bias = alibi_bias.reshape(-1, L_q, L_k) - if True: + if DEBUG_CORE: print("alibi_bias:", alibi_bias, alibi_bias.shape) attention_scaled_scores = attention_scaled_scores + alibi_bias if DEBUG_CORE: print("attention_scaled_scores after alibi:", attention_scaled_scores, attention_scaled_scores.shape) - # Apply causal mask if necessary - if causal: - L_q, L_k = q.shape[1], k.shape[1] - row_idx = torch.arange(L_q, device=q.device).unsqueeze(1) - col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) - col_offset = L_q-L_k - causal_mask = row_idx >= (col_offset + col_idx) + # Apply masks + L_q, L_k = q.shape[1], k.shape[1] + row_idx = torch.arange(L_q, device=q.device).unsqueeze(1) + col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) + col_offset = L_k - L_q + + mask_applied = False + if causal and (window_size_left, window_size_right) == (-1, -1): + # Pure causal: ensure query doesn't attend to future keys + mask = row_idx >= (col_idx - col_offset) + mask_applied = True + if DEBUG_CORE: + print("causal_mask:", mask) + elif (window_size_left, window_size_right) != (-1, -1): + # Handle the case where window sizes exceed sequence length + if window_size_left >= L_k: + window_size_left = -1 # No left limit + if window_size_right >= L_k: + window_size_right = -1 # No right limit + + if causal: + # Causal + sliding window: ensure we don't attend to future + window_size_right = min(window_size_right, 0) if window_size_right != -1 else 0 + + # Create sliding window mask + # Each query at position i attends to keys in [i + offset - left, i + offset + right] + if window_size_left == -1 and window_size_right == -1: + # No window restriction + mask = torch.ones((L_q, L_k), dtype=torch.bool, device=q.device) + else: + mask = torch.ones((L_q, L_k), dtype=torch.bool, device=q.device) + if window_size_left != -1: + # Each query at position i attends to keys from position (i - left) accounting for offset + mask = mask & (col_idx >= (row_idx + col_offset - window_size_left)) + if window_size_right != -1: + # Each query at position i attends to keys up to position (i + right) accounting for offset + mask = mask & (col_idx <= (row_idx + col_offset + window_size_right)) + + # Apply causal constraint + if causal: + causal_mask = row_idx >= (col_idx - col_offset) + mask = mask & causal_mask + + mask_applied = True if DEBUG_CORE: - print("causal_mask:", causal_mask) - # set -inf to places the causal mask is false + print(f"sliding_window_mask (left={window_size_left}, right={window_size_right}):", mask) + + # Apply the mask if created + if mask_applied: attention_scaled_scores = attention_scaled_scores.masked_fill( - torch.logical_not(causal_mask.unsqueeze(0)), float('-inf') + torch.logical_not(mask.unsqueeze(0)), float('-inf') ) if DEBUG_CORE: - print("attention_scaled_scores after causal:", attention_scaled_scores, attention_scaled_scores.shape) + print("attention_scaled_scores after masking:", attention_scaled_scores, attention_scaled_scores.shape) # compute probabilities using softmax_lse if use_exp2: RCP_LN = 1 / math.log(2) attention_scaled_scores_base2 = attention_scaled_scores * RCP_LN softmax_lse_base2 = softmax_lse * RCP_LN - softmax_lse_3d = softmax_lse_base2.unsqueeze(-1) + softmax_lse_3d = softmax_lse_base2.unsqueeze(-1) p = torch.exp2(attention_scaled_scores_base2 - softmax_lse_3d) else: - softmax_lse_3d = softmax_lse.unsqueeze(-1) + softmax_lse_3d = softmax_lse.unsqueeze(-1) p = torch.exp(attention_scaled_scores - softmax_lse_3d) + + # Zero out positions outside the mask + if mask_applied: + p = p.masked_fill(torch.logical_not(mask.unsqueeze(0)), 0.0) + if DEBUG_CORE: print("softmax_lse_3d:", softmax_lse_3d, softmax_lse_3d.shape) print("p:", p, p.shape) if dropout_p > 0.0: rand_vals = torch.rand(p.shape, generator=torch.Generator(device=p.device).manual_seed(philox_seed), device=p.device, dtype=p.dtype) - dropout_mask, dropout_scale = rand_vals > dropout_p, (1.0 / (1 - dropout_p)) - if DEBUG: + dropout_mask, dropout_scale = rand_vals > dropout_p, (1.0 / (1 - dropout_p)) + if DEBUG_CORE: print("dropout_scale:", dropout_scale) print("dropout_mask:", dropout_mask) p_drop = torch.where(dropout_mask, p, torch.zeros_like(p)) - p_drop_scaled = p_drop * dropout_scale + p_drop_scaled = p_drop * dropout_scale if DEBUG_CORE: print("dropout_scale:", dropout_scale) print("p_drop:", p_drop, p_drop.shape) @@ -107,7 +153,7 @@ def attention_backward_core_ref_impl( # compute dp dp_dropout = torch.matmul(do, v.transpose(-2, -1)) - dp = torch.where(dropout_mask, dp_dropout , torch.zeros_like(dp_dropout)) * dropout_scale + dp = torch.where(dropout_mask, dp_dropout, torch.zeros_like(dp_dropout)) * dropout_scale if DEBUG_CORE: print("dp_dropout:", dp_dropout, dp_dropout.shape) print("dp:", dp, dp.shape) @@ -127,9 +173,14 @@ def attention_backward_core_ref_impl( delta = torch.sum(o * do, axis=-1).unsqueeze(-1) else: delta = torch.sum(p * dp, axis=-1).unsqueeze(-1) - if DEBUG: + if DEBUG_CORE: print("delta:", delta, delta.shape) dscores_scaled = p * (dp - delta) + + # Zero out gradients for positions outside the mask + if mask_applied: + dscores_scaled = dscores_scaled.masked_fill(torch.logical_not(mask.unsqueeze(0)), 0.0) + ds = dscores_scaled * sm_scale if DEBUG_CORE: print("dscores_scaled:", dscores_scaled, dscores_scaled.shape) @@ -167,6 +218,8 @@ def attention_varlen_backward_pytorch_ref_impl( softmax_lse, sm_scale, causal, + window_size_left, + window_size_right, layout, cu_seqlens_q, cu_seqlens_k, @@ -255,6 +308,8 @@ def attention_varlen_backward_pytorch_ref_impl( softmax_lse_i, sm_scale, causal, + window_size_left, + window_size_right, dropout_p, philox_seed, philox_offset, @@ -270,7 +325,8 @@ def attention_varlen_backward_pytorch_ref_impl( if group_size != 1: # Reshape dq_i and delta_i back to original shape dq_i = dq_i.view(dq_i.shape[0], nheads_k, group_size, head_dim) - delta_i = delta_i.view(delta_i.shape[0], nheads_k, group_size) + L_q_i = delta_i.shape[1] + delta_i = delta_i.view(nheads_k, group_size, L_q_i) # Sum dk_i and dv_i over group dimension dk_i = dk_i.view(dk_i.shape[0], nheads_k, group_size, head_dim) dv_i = dv_i.view(dv_i.shape[0], nheads_k, group_size, head_dim) @@ -278,7 +334,7 @@ def attention_varlen_backward_pytorch_ref_impl( dv_i = dv_i.sum(dim=2) # Reshape dq_i back to [L_q_i, nheads_q, head_dim] dq_i = dq_i.reshape(dq_i.shape[0], nheads_q, head_dim) - delta_i = delta_i.reshape(delta_i.shape[0], nheads_q) + delta_i = delta_i.reshape(nheads_q, L_q_i) else: # No need to reshape pass @@ -300,6 +356,8 @@ def attention_vanilla_backward_pytorch_ref_impl( softmax_lse, sm_scale, causal, + window_size_left, + window_size_right, layout, dropout_p, philox_seed, @@ -366,6 +424,8 @@ def attention_vanilla_backward_pytorch_ref_impl( softmax_lse, sm_scale, causal, + window_size_left, + window_size_right, dropout_p, philox_seed, philox_offset, @@ -421,6 +481,8 @@ def attention_backward_pytorch_ref_impl( sm_scale: float, alibi_slopes: Optional[torch.Tensor], causal: bool, + window_size_left: int, + window_size_right: int, layout: Literal["bshd", "bhsd", "thd"], cu_seqlens_q: Optional[torch.Tensor], cu_seqlens_k: Optional[torch.Tensor], @@ -441,6 +503,8 @@ def attention_backward_pytorch_ref_impl( softmax_lse, sm_scale, causal, + window_size_left, + window_size_right, layout, cu_seqlens_q, cu_seqlens_k, @@ -462,6 +526,8 @@ def attention_backward_pytorch_ref_impl( softmax_lse, sm_scale, causal, + window_size_left, + window_size_right, layout, dropout_p, philox_seed, @@ -476,4 +542,4 @@ def attention_backward_pytorch_ref_impl( dk.copy_(dk_ref.to(dk.dtype)) dq.copy_(dq_ref.to(dq.dtype)) - return delta + return delta \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index e165d714876..d3f7f9c32b9 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -133,6 +133,9 @@ def _fwd_kernel_splitK( USE_ALIBI: tl.constexpr, PADDED_HEAD: tl.constexpr, GROUP_SIZE: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, ): # get program ids pid_m = tl.program_id(0) @@ -297,35 +300,62 @@ def _fwd_kernel_splitK( alibi_bias = -1 * alibi_slope * relative_pos qk += (alibi_bias * 1.44269504) - # Apply causal mask if IS_CAUSAL is True - if IS_CAUSAL: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = start_n + tl.arange(0, BLOCK_N) - - # create a N_CTX_Q x kv_len causal mask - col_offset = N_CTX_Q - N_CTX_K_FINAL - causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :]) + # ------------------------------------------------------------------ + # masking + # ------------------------------------------------------------------ + if USE_SLIDING_WINDOW: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # q positions + col_idx = start_n + tl.arange(0, BLOCK_N) # k positions + row = row_idx[:, None] # [M,1] + col = col_idx[None, :] # [1,N] + + if IS_CAUSAL: + # -------- causal + window -------- + diag = N_CTX_K_FINAL - N_CTX_Q # sk-sq + causal_ok = col <= row + diag + if WINDOW_SIZE_LEFT < 0: # only right window + win_ok = col <= row + diag + WINDOW_SIZE_RIGHT + else: # both sides + win_ok = ((col >= row + diag - WINDOW_SIZE_LEFT) & + (col <= row + diag + WINDOW_SIZE_RIGHT)) + mask = ~(causal_ok & win_ok) # True ⇒ -inf + else: + # -------- non-causal window -------- + sk, sq = N_CTX_K_FINAL, N_CTX_Q + if WINDOW_SIZE_LEFT < 0: + mask = col > row + (sk - sq) + WINDOW_SIZE_RIGHT + else: + right = tl.minimum(row + (sk - sq) + WINDOW_SIZE_RIGHT, sk) + left = row + (sk - sq) - WINDOW_SIZE_LEFT + mask = (col > right) | (col < left) + qk = tl.where(mask, float("-inf"), qk) + else: + if IS_CAUSAL: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = start_n + tl.arange(0, BLOCK_N) - # Apply the mask - qk = tl.where(causal_mask, qk, float("-inf")) + # create a N_CTX_Q x kv_len causal mask + col_offset = N_CTX_Q - N_CTX_K_FINAL + causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :]) + + # Apply the mask + qk = tl.where(causal_mask, qk, float("-inf")) # TODO: This is slow, and only needed at the last iteration. # Maybe we can unroll the last iteration instead? if BOUNDS_CHECKS_N: qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) - # -- compute scaling constant --- - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - if IS_CAUSAL: - alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf"))) - else: - alpha = tl.math.exp2(m_i - m_i_new) - # cause of nan because subtracting infs - if IS_CAUSAL: - qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf")) - else: - qk = qk - m_i_new[:, None] - + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) # per-row max so far + + # rows that are *all* -inf after masking + valid = m_i_new > float("-inf") + + # scale previous partial sums safely + alpha = tl.where(valid, tl.math.exp2(m_i - m_i_new), 0.0) + + # subtract the row max only on valid rows + qk = tl.where(valid[:, None], qk - m_i_new[:, None], float("-inf")) p = tl.math.exp2(qk) # -- update m_i and l_i -- @@ -387,7 +417,6 @@ def _splitK_reduce( split_k: tl.constexpr, splitK_pow2: tl.constexpr, MASK_SPLITK: tl.constexpr, - IS_CAUSAL: tl.constexpr, PADDED_HEAD: tl.constexpr, ): # get pids @@ -426,23 +455,15 @@ def _splitK_reduce( g_m = tl.max(l_m, axis=0) - if IS_CAUSAL: - l_m_offset = l_m - g_m - alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0) - else: - alpha = tl.math.exp2(l_m - g_m) + alpha = tl.where(l_m > float("-inf"), tl.math.exp2(l_m - g_m), 0.0) # read sum l_sum *= alpha g_sum = tl.sum(l_sum, axis=0) acc = acc * alpha[:, None] - if IS_CAUSAL: - # Avoid division by zero - g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0) - acc_out = tl.sum(acc, axis=0) / g_sum_safe - else: - acc_out = tl.sum(acc, axis=0) / g_sum + g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0) + acc_out = tl.sum(acc, axis=0) / g_sum_safe # Store output z_id = pid_zhg // (H * G) @@ -454,11 +475,10 @@ def _splitK_reduce( # Store lse l_ptrs = LSE + pid_zhg * stride_lse_zhg + pid_m - if IS_CAUSAL: - lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m) - tl.store(l_ptrs, lse) - else: - tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) + lse_val = tl.where(g_sum > 0, + (g_m + tl.math.log2(g_sum)) / 1.44269504, + g_m) + tl.store(l_ptrs, lse_val) @triton.jit @@ -571,10 +591,12 @@ def attention_decode_forward_triton_impl( v_new: Optional[torch.Tensor], out: torch.Tensor, sm_scale: float, - causal: bool, + causal: bool, + window_size_left: int, + window_size_right: int, alibi_slopes: Optional[torch.Tensor], layout: Literal["bshd"], - cache_seqlens: Optional[Union[(int, torch.Tensor)]], + cache_seqlens: Optional[torch.Tensor], cache_batch_idx: Optional[torch.Tensor], ): # triton configs @@ -586,8 +608,9 @@ def attention_decode_forward_triton_impl( # kernel_configs is_new_kv = True if k_new is not None and v_new is not None else False - use_alibi = False if alibi_slopes is None else True + use_alibi, (stride_az, stride_ah) = True if alibi_slopes is not None else False, alibi_slopes.stride() if alibi_slopes is not None else (None, None) use_cache_seqlens = cache_seqlens is not None + use_sliding_window = window_size_left != -1 or window_size_right != -1 SPLIT_K = None NUM_QUANT_GROUPS = 1 @@ -602,11 +625,6 @@ def attention_decode_forward_triton_impl( ( _, seqlen_kn, nheads_kn, dim_kn), (stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d) = (None, None, None, None), (None, None, None, None) (_, seqlen_vn, nheads_vn, dim_vn), (stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d) = (None, None, None, None), (None, None, None, None) (_, seqlen_o, nheads_o, dim_o), (stride_oz, stride_oh, stride_om, stride_od) = get_shape_and_strides_from_layout(out, layout) - if use_alibi: - stride_az, stride_ah = alibi_slopes.stride() - else: - stride_az, stride_ah = (None, None) - assert dim_q == dim_kc == dim_vc, f"Dimensions must match: {dim_q}, {dim_kc}, {dim_vc}" # add extra information needed by the kernels @@ -656,7 +674,7 @@ def attention_decode_forward_triton_impl( stride_mzhg, stride_m2, stride_ms, stride_mm = metadata.stride() stride_lse_zhg, stride_lse_m = lse.stride() - if False: + if DEBUG: print("batch_size, seqlen_q, nheads_q, dim_q", (batch_size, seqlen_q, nheads_q, dim_q)) print("_, seqlen_kc, nheads_kc, dim_kc", (_, seqlen_kc, nheads_kc, dim_kc)) print("dim_padded:", dim_padded) @@ -748,6 +766,9 @@ def attention_decode_forward_triton_impl( USE_ALIBI=use_alibi, PADDED_HEAD=is_padded_head, GROUP_SIZE=group_size, + USE_SLIDING_WINDOW=use_sliding_window, + WINDOW_SIZE_LEFT=window_size_left, + WINDOW_SIZE_RIGHT=window_size_right, num_warps=num_warps_fwd, num_stages=num_stages, ) @@ -809,8 +830,7 @@ def attention_decode_forward_triton_impl( split_k=split_k, splitK_pow2=splitK_pow2, MASK_SPLITK=mask_split_k, - IS_CAUSAL=causal, PADDED_HEAD=is_padded_head, num_warps=num_warps_reduce) - return lse + return lse \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index e33982bb6a7..59fe8bfaf4e 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -1,8 +1,9 @@ +import os import torch import triton import triton.language as tl from typing import Literal, Optional, Union -from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_arch, is_cdna, is_fp8, is_rdna, create_dropout_mask +from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_arch, is_cdna, is_fp8, is_rdna, create_dropout_mask, get_fwd_prefill_autotune_configs DEBUG = False @@ -10,70 +11,192 @@ tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) -# Convenience function to load with optional boundary checks. -# "First" is the major dim, "second" is the minor dim. +fwd_prefill_autotune_configs, fwd_prefill_autotune_keys = get_fwd_prefill_autotune_configs() + @triton.jit -def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): - if offset_first is not None and offset_second is not None: - mask = (offset_first[:, None] < boundary_first) & \ - (offset_second[None, :] < boundary_second) - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_first is not None: - mask = offset_first[:, None] < boundary_first - tensor = tl.load(ptrs, mask=mask, other=0.0) - elif offset_second is not None: - mask = offset_second[None, :] < boundary_second - tensor = tl.load(ptrs, mask=mask, other=0.0) - else: - tensor = tl.load(ptrs) - return tensor +def _attn_fwd_no_mask(acc, l_i, m_i, + q, k_base_ptrs, v_base_ptrs, bias_base_ptrs, + stride_kn, stride_vk, stride_bn, stride_sn, + start_m, seqlen_k, seqlen_q, + dropout_p, philox_seed, philox_base_ptrs, + sd_mask_base_ptrs, dropout_mask_base_ptrs, + offs_m, offs_n, offs_d, + block_min, block_max, alibi_slope, + descale_q, descale_k, descale_v, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + RETURN_SCORES: tl.constexpr, ACCUMULATOR_TYPE): + if USE_EXP2: + RCP_LN2: tl.constexpr = 1.4426950408889634 + + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # get ptrs + k_ptrs = k_base_ptrs + start_n * stride_kn + v_ptrs = v_base_ptrs + start_n * stride_vk + + kv_offs_n = start_n + tl.arange(0, BLOCK_N) + if PADDED_HEAD: + k_mask, k_mask_other = (offs_d[:, None] < ACTUAL_BLOCK_DMODEL), 0.0 + v_mask, v_mask_other = (offs_d[None, :] < ACTUAL_BLOCK_DMODEL), 0.0 + + # load k and if preload_v then v + k = tl.load(k_ptrs, mask=k_mask, other=k_mask_other) if PADDED_HEAD else tl.load(k_ptrs) + if PRE_LOAD_V: + v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD else tl.load(v_ptrs) + + # setup qk accumlator + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) + + # -- compute qk ---- + if IS_FP8 : + qk += (tl.dot(q, k) * descale_q * descale_k) + else: + qk += tl.dot(q, k) + qk_scaled = qk * SM_SCALE + + if USE_ALIBI: + # compute the global position of each token within the sequence + q_offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + alibi_block = compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, q_offs_m, + kv_offs_n) + qk_scaled += alibi_block + + # compute qk mask + qk_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) + + # compute bias + if bias_base_ptrs is not None: + bias_ptrs = bias_base_ptrs + start_n * stride_bn + bias = tl.load(bias_ptrs, mask=qk_mask, other=0.0) + qk_scaled += bias + + # get max scores so far + m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) + + # scale and subtract max + q_shifted = tl.where(m_ij[:, None] == float("-inf"), + float("-inf"), + qk_scaled - m_ij[:, None]) + + # Compute scaled QK and softmax probabilities + if USE_EXP2: + p = tl.math.exp2(q_shifted * RCP_LN2) + else: + p = tl.math.exp(q_shifted) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + dropout_mask_ptrs = dropout_mask_base_ptrs + start_n * stride_sn + sd_mask_ptrs = sd_mask_base_ptrs + start_n * stride_sn + philox_ptrs = philox_base_ptrs + start_n * stride_sn + if tl_DROPOUT_USE_PYTORCH: + dropout_mask = tl.load(dropout_mask_ptrs, mask=qk_mask) + else: + rng_output = tl.rand(philox_seed, philox_ptrs) # TODO: use tl.randint for better performance + dropout_mask = rng_output > dropout_p + if tl_DROPOUT_DUMP: + tl.store(dropout_mask_ptrs, dropout_mask, mask=qk_mask) + + # return scores with negative values for dropped vals + sd_mask = tl.where(dropout_mask, p, -p) + tl.store(sd_mask_ptrs, sd_mask, mask=qk_mask) + + # apply dropout mask in place + p = tl.where(dropout_mask, p, 0.0) + elif RETURN_SCORES: + # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that + sd_mask_ptrs = sd_mask_base_ptrs + start_n * stride_sn + tl.store(sd_mask_ptrs, p, mask=qk_mask) + + # -- update output accumulator -- + # alpha is an adjustment factor for acc and li as we loop and find new maxes + # store the diff in maxes to adjust acc and li as we discover new maxes + m_diff = tl.where(m_ij == float("-inf"), + float("-inf"), + m_i - m_ij) + if USE_EXP2: + alpha = tl.math.exp2(m_diff * RCP_LN2) + else: + alpha = tl.math.exp(m_diff) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD else tl.load(v_ptrs) + + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + if IS_FP8: + scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) + acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) + else: + acc += tl.dot(p.to(v.type.element_ty), v) + + return acc, l_i, m_i @triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, - actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, - block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, +def _attn_fwd_mask(acc, l_i, m_i, + q, k_base_ptrs, v_base_ptrs, bias_base_ptrs, + stride_kn, stride_vk, stride_bn, stride_sn, start_m, + seqlen_k, seqlen_q, + dropout_p, philox_seed, philox_base_ptrs, + sd_mask_base_ptrs, dropout_mask_base_ptrs, + offs_m, offs_n, offs_d, + block_min, block_max, n_extra_tokens, alibi_slope, descale_q, descale_k, descale_v, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, + PRE_LOAD_V: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, - RETURN_SCORES: tl.constexpr, ACCUMULATOR_TYPE): + RETURN_SCORES: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, + ACCUMULATOR_TYPE): if USE_EXP2: RCP_LN2: tl.constexpr = 1.4426950408889634 + + # seqlen diff + seqlen_delta_qk = seqlen_k - seqlen_q # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): + # get ptrs + k_ptrs = k_base_ptrs + start_n * stride_kn + v_ptrs = v_base_ptrs + start_n * stride_vk + # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. - if MASK_STEPS: - k_offs_n = start_n + tl.arange(0, BLOCK_N) - else: - k_offs_n = None - k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) - k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) + kv_offs_n = start_n + tl.arange(0, BLOCK_N) + k_mask = (kv_offs_n[None, :] < seqlen_k) + v_mask = (kv_offs_n[:, None] < seqlen_k) + if PADDED_HEAD: + k_mask = k_mask & (offs_d[:, None] < ACTUAL_BLOCK_DMODEL) + v_mask = v_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + + # load k and if preload_v then v + k = tl.load(k_ptrs, mask=k_mask, other = 0.0) if PRE_LOAD_V: - # We can use the same offsets as k, just with dims transposed. - v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + + # setup qk accumlator qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) + # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: - # If this is the last block / iteration, we want to - # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. - # last step might get wasted but that is okay. check if this masking works For - # that case. - if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) - size_n = start_n + OFFS_N[None, :] - mask = size_n < boundary_m[:, None] - qk = tl.where(mask, qk, float("-inf")) - - # compute masks - q_mask = (OFFS_M[:, None] < actual_seqlen_q) - k_mask = ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) - p_mask = q_mask & k_mask + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. + if (n_extra_tokens != 0) and (start_n + BLOCK_N == block_max): + boundary_m = tl.full([BLOCK_M], seqlen_k, dtype=tl.int32) + size_n = start_n + offs_n[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) # -- compute qk ---- if IS_FP8 : @@ -84,27 +207,121 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri if USE_ALIBI: # compute the global position of each token within the sequence - global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - global_n_positions = start_n + tl.arange(0, BLOCK_N) - alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, - global_n_positions) + q_offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + alibi_block = compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, q_offs_m, + kv_offs_n) qk_scaled += alibi_block - if IS_CAUSAL: - causal_boundary = start_n + offs_n_causal - causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] - qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) + if USE_SLIDING_WINDOW: + if IS_CAUSAL: + # ========== CAUSAL SLIDING WINDOW MASKING ========== + # For causal sliding window, we need to apply both constraints: + # 1. Causal: col_idx <= row_idx + (seqlen_k - seqlen_q) + # 2. Sliding window: row_idx - window_left <= col_idx <= row_idx + window_right + + # Get positions + row_idx = offs_m # Query positions + col_idx = kv_offs_n # Key positions + + # Expand for broadcasting + row_idx_expanded = row_idx[:, None] # [BLOCK_M, 1] + col_idx_expanded = col_idx[None, :] # [1, BLOCK_N] + + # Apply causal constraint: can only attend to positions before or at the diagonal + causal_offset = seqlen_k - seqlen_q + causal_mask = col_idx_expanded > (row_idx_expanded + causal_offset) + + # Apply sliding window constraint + if WINDOW_SIZE_LEFT < 0: + # Only right window constraint + window_mask = col_idx_expanded > (row_idx_expanded + causal_offset + WINDOW_SIZE_RIGHT) + else: + # Both left and right window constraints + # Adjust window bounds by causal offset + left_bound = row_idx_expanded + causal_offset - WINDOW_SIZE_LEFT + right_bound = row_idx_expanded + causal_offset + WINDOW_SIZE_RIGHT + + # Can't attend to positions outside the window + window_mask = (col_idx_expanded < left_bound) | (col_idx_expanded > right_bound) + + # Final mask is the union of both constraints (True = cannot attend) + mask = causal_mask | window_mask + + # Apply mask + qk_scaled = tl.where(mask, float("-inf"), qk_scaled) + else: + # ========== NON-CAUSAL SLIDING WINDOW MASKING ========== + # Exactly matching reference construct_local_mask: + # row_idx = query positions, col_idx = key positions + # sk = seqlen_k, sq = seqlen_q + + # Get positions + row_idx = offs_m # Query positions + col_idx = kv_offs_n # Key positions + + # sk and sq from reference (no padding masks in this test) + sk = seqlen_k + sq = seqlen_q + + # Expand for broadcasting + row_idx_expanded = row_idx[:, None] # [BLOCK_M, 1] + col_idx_expanded = col_idx[None, :] # [1, BLOCK_N] + + # Reference logic for mask computation + if WINDOW_SIZE_LEFT < 0: + # Reference: return col_idx > row_idx + sk - sq + window_size[1] + mask = col_idx_expanded > (row_idx_expanded + sk - sq + WINDOW_SIZE_RIGHT) + else: + # Reference: + # sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + # return torch.logical_or( + # col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + # col_idx < row_idx + sk - sq - window_size[0], + # ) + # Create sk tensor with proper shape for broadcasting + # sk represents the key sequence length, which should be compared per column + sk_full = tl.full((1, BLOCK_N), sk, dtype=tl.int32) + + # Compute boundaries + right_bound_val = row_idx_expanded + sk - sq + WINDOW_SIZE_RIGHT + right_bound = tl.minimum(right_bound_val, sk_full) + left_bound = row_idx_expanded + sk - sq - WINDOW_SIZE_LEFT + + # Mask where True = cannot attend (matching reference) + mask = (col_idx_expanded > right_bound) | (col_idx_expanded < left_bound) + + # Apply mask (set to -inf where mask is True) + qk_scaled = tl.where(mask, float("-inf"), qk_scaled) + else: + if IS_CAUSAL: + causal_boundary = start_n + offs_n - seqlen_delta_qk + causal_mask = offs_m[:, None] >= causal_boundary[None, :] + qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) - if bias_ptrs is not None: - bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None - bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) + # compute qk mask + qk_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) + + # compute bias + if bias_base_ptrs is not None: + bias_ptrs = bias_base_ptrs + start_n * stride_bn + bias = tl.load(bias_ptrs, mask=qk_mask, other=0.0) qk_scaled += bias # get max scores so far m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) # scale and subtract max - q_shifted = qk_scaled - m_ij[:, None] + # IMPORTANT: Handle the case where all values are -inf + # When m_ij = -inf and qk_scaled = -inf, subtraction gives NaN + # We need to handle this explicitly + if USE_SLIDING_WINDOW: + # Check if this block has any valid values (m_ij != -inf) + # For rows where everything is -inf, set q_shifted to -inf (not NaN) + q_shifted = tl.where(m_ij[:, None] == float("-inf"), + float("-inf"), + qk_scaled - m_ij[:, None]) + else: + q_shifted = qk_scaled - m_ij[:, None] # Compute scaled QK and softmax probabilities if USE_EXP2: @@ -115,38 +332,44 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: + dropout_mask_ptrs = dropout_mask_base_ptrs + start_n * stride_sn + sd_mask_ptrs = sd_mask_base_ptrs + start_n * stride_sn + philox_ptrs = philox_base_ptrs + start_n * stride_sn if tl_DROPOUT_USE_PYTORCH: - dropout_mask = tl.load(dropout_mask_ptrs, mask=p_mask) + dropout_mask = tl.load(dropout_mask_ptrs, mask=qk_mask) else: rng_output = tl.rand(philox_seed, philox_ptrs) # TODO: use tl.randint for better performance dropout_mask = rng_output > dropout_p if tl_DROPOUT_DUMP: - tl.store(dropout_mask_ptrs, dropout_mask, mask=p_mask) + tl.store(dropout_mask_ptrs, dropout_mask, mask=qk_mask) # return scores with negative values for dropped vals sd_mask = tl.where(dropout_mask, p, -p) - tl.store(sd_mask_ptrs, sd_mask, mask=p_mask) + tl.store(sd_mask_ptrs, sd_mask, mask=qk_mask) # apply dropout mask in place p = tl.where(dropout_mask, p, 0.0) elif RETURN_SCORES: # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - tl.store(sd_mask_ptrs, p, mask=p_mask) + sd_mask_ptrs = sd_mask_base_ptrs + start_n * stride_sn + tl.store(sd_mask_ptrs, p, mask=qk_mask) # -- update output accumulator -- # alpha is an adjustment factor for acc and li as we loop and find new maxes # store the diff in maxes to adjust acc and li as we discover new maxes - m_diff = m_i - m_ij + m_diff = tl.where(m_ij == float("-inf"), + float("-inf"), + m_i - m_ij) if USE_EXP2: alpha = tl.math.exp2(m_diff * RCP_LN2) else: alpha = tl.math.exp(m_diff) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + # -- update m_i and l_i l_i = l_i * alpha + l_ij - # update m_i and l_i m_i = m_ij if IS_FP8: @@ -154,108 +377,140 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) else: acc += tl.dot(p.to(v.type.element_ty), v) - - k_ptrs += BLOCK_N * stride_kn - v_ptrs += BLOCK_N * stride_vk - if bias_ptrs is not None: - bias_ptrs += BLOCK_N * stride_bn - if RETURN_SCORES: - sd_mask_ptrs += BLOCK_N * stride_sn - - if ENABLE_DROPOUT: - dropout_mask_ptrs += BLOCK_N * stride_sn - philox_ptrs += BLOCK_N * stride_sn + return acc, l_i, m_i -def get_cdna_autotune_configs(): - return [ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - # Fall-back config. - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] - - -def get_rdna_autotune_configs(): - return [ - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - # Fall-back config. - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] - - -def get_autotune_configs(): - if AUTOTUNE: - if is_rdna(): - return get_rdna_autotune_configs() - elif is_cdna(): - return get_cdna_autotune_configs() +@triton.jit +def compute_masking(seqlen_k, seqlen_q, start_m, + IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + """ + Classify K blocks for attention computation with sliding window support. + + Returns: + - n_front_skip_blocks: Blocks completely before the window + - n_front_masked_blocks: Blocks partially overlapping window front + - n_full_blocks: Blocks completely inside the window + - n_back_masked_blocks: Blocks partially overlapping window back + - n_extra_tokens: Padding tokens in last K block + """ + # Example case + # BLOCK_M = 4, BLOCK_N = 4, seqlen_q = 8, seqlen_k = 10 + + # Total K blocks in the key sequence + total_k_blocks = tl.cdiv(seqlen_k, BLOCK_N) + + # check if we will need to do masking due either BLOCK_N being bigger than seqlen_k or seqlen_k not being a factor of BLOCK_N + # n_extra_tokens = 10 % 4 = 2 + # This means the last K block has 2 valid tokens and 2 padding positions + # K blocks visualization: + # Block 0 Block 1 Block 2 (last) + # K0 K1 K2 K3 K4 K5 K6 K7 K8 K9 ?? ?? + # ↑---------↑ ↑---------↑ ↑---↑ ↑---↑ + # full block full block valid pad + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + else: + n_extra_tokens = 0 + + if USE_SLIDING_WINDOW: + # TODO: Optimize by computing which blocks can be fully skipped + # For now, process all blocks with the mask function + if IS_CAUSAL: + return 0, 0, 0, total_k_blocks, n_extra_tokens else: - raise ValueError("Unknown Device Type") + return 0, 0, 0, total_k_blocks, n_extra_tokens else: - arch = get_arch() - if arch == "gfx950": - default_config = triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ) - elif arch == "gfx942" and False: # Disabled due shared mem oom in CI when using triton==3.3.0 when using top of tree everything seems fine. - default_config = triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ) + if IS_CAUSAL: + # ========== CAUSAL MODE: Classify K Blocks ========== + # Calculate causal boundary for this Q block + # [K0 K1 K2 K3] [K4 K5 K6 K7] [K8 K9 ?? ??] + # Q0-Q3: [ 1 0 0 0] [ 0 0 0 0] [ 0 0 -- --] ← Q0 + # [ 1 1 0 0] [ 0 0 0 0] [ 0 0 -- --] ← Q1 + # [ 1 1 1 0] [ 0 0 0 0] [ 0 0 -- --] ← Q2 + # [ 1 1 1 1] [ 1 1 0 0] [ 0 0 -- --] ← Q3 + # ↑ can see up to K5 + # + # Q4-Q7: [ 1 1 1 1] [ 1 1 1 0] [ 0 0 -- --] ← Q4 + # [ 1 1 1 1] [ 1 1 1 1] [ 0 0 -- --] ← Q5 + # [ 1 1 1 1] [ 1 1 1 1] [ 1 0 -- --] ← Q6 + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -- --] ← Q7 + + # ------------------------------------------------------------ + # 1. figure out, in tokens, the right-most K position + # this Q-block may attend to + # ------------------------------------------------------------ + q_start = start_m * BLOCK_M + q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) + + # causal diagonal offset between the two streams + diag = seqlen_k - seqlen_q # 0 when |Q| == |K| + k_max_token = q_end + diag # last visible K index + + # this Q-block is entirely above the diagonal ⇒ nothing to do + if k_max_token < 0: + return 0, 0, 0, 0, n_extra_tokens + + k_max_token = tl.minimum(k_max_token, seqlen_k - 1) + + # ------------------------------------------------------------ + # 2. translate token indices into K-block indices + # ------------------------------------------------------------ + last_visible_k_block = k_max_token // BLOCK_N + n_visible_k_blocks = tl.minimum(last_visible_k_block + 1, total_k_blocks) + + # ------------------------------------------------------------ + # 3. classify those visible blocks + # – we *never* skip or mask blocks in front, because causal + # attention always starts at K0 + # – the back side can require several masked blocks: + # • intersection of the causal diagonal with K-grid + # (at most ⌈BLOCK_M / BLOCK_N⌉ blocks) + # • plus one extra block if this Q-block stops in the + # middle of a K-block or the last K-block is padded + # ------------------------------------------------------------ + padded_last_k = n_extra_tokens != 0 + is_modulo_mn = (not padded_last_k) & (seqlen_q % BLOCK_M == 0) + + n_back_masked_blocks = BLOCK_M // BLOCK_N + tl.where(is_modulo_mn, 0, 1) + n_back_masked_blocks = tl.minimum(n_back_masked_blocks, n_visible_k_blocks) + + n_front_skip_blocks = 0 # causal never skips the left side + n_front_masked_blocks = 0 # ditto + n_full_blocks = n_visible_k_blocks - n_back_masked_blocks else: - default_config = triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ) - - return [ - default_config - ], [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_BLOCK_DMODEL", - "IS_VARLEN", - "HQ", - "HK", - ] - - -autotune_configs, autotune_keys = get_autotune_configs() + # ========== NON-CAUSAL MODE ========== + # Without causal mask, all positions can attend to all positions + # Only need to handle the padding in the last block + # [K0 K1 K2 K3] [K4 K5 K6 K7] [K8 K9 ?? ??] + # Q0-Q3: [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # + # Q4-Q7: [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + + n_front_skip_blocks = 0 # never skips the left side + n_front_masked_blocks = 0 # ditto + if n_extra_tokens != 0: + n_back_masked_blocks = 1 # Last block needs padding mask + n_full_blocks = total_k_blocks - 1 + else: + n_back_masked_blocks = 0 # All blocks are aligned + n_full_blocks = total_k_blocks + + return n_front_skip_blocks, n_front_masked_blocks, n_full_blocks, n_back_masked_blocks, n_extra_tokens @triton.autotune( - configs=autotune_configs, - key=autotune_keys, + configs=fwd_prefill_autotune_configs, + key=fwd_prefill_autotune_keys, use_cuda_graph=True, ) @triton.jit @@ -267,7 +522,8 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, IS_INFERENCE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr, FLIP_GRID: tl.constexpr): @@ -276,22 +532,21 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, # compute offsets if FLIP_GRID: - #NUM_XCDS: tl.constexpr = 8 off_z = tl.program_id(0) off_h_q = tl.program_id(1) start_m = tl.program_id(2) - - #start_m = (tl.cdiv(MAX_SEQLENS_Q, BLOCK_M) - 1) - start_m - - # Remap heads to the same XCD - #pids_per_xcd = HQ // NUM_XCDS - #xcd_group = off_h_q % NUM_XCDS - #pid_in_xcd = off_h_q // NUM_XCDS - #off_h_q = xcd_group * pids_per_xcd + pid_in_xcd else: start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q + # Determine if we need to mask the heads + PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) @@ -309,76 +564,57 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start - elif IS_INFERENCE: - cu_seqlens_q_start = 0 - cu_seqlens_k_start = 0 - seqlen_q = MAX_SEQLENS_Q - seqlen_k = tl.load(Cache_seqlens + off_z) else: cu_seqlens_q_start = 0 cu_seqlens_k_start = 0 seqlen_q = MAX_SEQLENS_Q seqlen_k = MAX_SEQLENS_K - # Now we compute whether we need to exit early due to causal masking. - # This is because for seqlen_q > seqlen_k, M rows of the attn scores - # are completely masked, resulting in 0s written to the output, and - # inf written to LSE. We don't need to do any GEMMs in this case. - # This block of code determines what N is, and if this WG is operating - # on those M rows. - n_blocks = tl.cdiv(seqlen_k, BLOCK_N) - if (IS_CAUSAL): - # If seqlen_q == seqlen_k, the attn scores are a square matrix. - # If seqlen_q != seqlen_k, attn scores are rectangular which means - # the causal mask boundary is bottom right aligned, and ends at either - # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn matrix - n_blocks_seqlen = tl.cdiv((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) - # This is what adjusts the block_max for the current WG, only - # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks - n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is part of - # the blocks that are all 0. We exit early. - if n_blocks <= 0: - o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om - o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) - o_ptrs_mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) - # We still need to write 0s to the result - tl.store(o_ptrs, acc, mask=o_ptrs_mask) - # The tensor allocated for L is based on MAX_SEQLENS_Q as that is - # statically known. - l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m - l_ptrs = l_offset + offs_m * stride_lse_m - - l = tl.full([BLOCK_M], value=0.0, dtype=ACCUMULATOR_TYPE) - - # mask_m_offsets = start_m + tl.arange(0, BLOCK_M) - # lse_mask = mask_m_offsets < causal_start_idx - # softmax_lse = tl.where(lse_mask, 0.0, softmax_lse) - l_ptrs_mask = offs_m < MAX_SEQLENS_Q - tl.store(l_ptrs, l, mask=l_ptrs_mask) - # TODO: Should dropout and return encoded softmax be handled here too? - return - - # If MQA / GQA, set the K and V head offsets appropriately. - GROUP_SIZE: tl.constexpr = HQ // HK - if GROUP_SIZE != 1: - off_h_k = off_h_q // GROUP_SIZE + # Load scale factors if IS_FP8. + if IS_FP8: + descale_q = tl.load(Descale_Q + off_z * stride_descale_q_z + off_h_q) + descale_k = tl.load(Descale_K + off_z * stride_descale_k_z + off_h_k) + descale_v = tl.load(Descale_V + off_z * stride_descale_v_z + off_h_k) else: - off_h_k = off_h_q - - n_extra_tokens = 0 - # print("n_extra_tokens:", n_extra_tokens) - # print("seqlen_k:", seqlen_k) - # print("BLOCK_N:", BLOCK_N) - # return - if seqlen_k < BLOCK_N: - n_extra_tokens = BLOCK_N - seqlen_k - elif seqlen_k % BLOCK_N: - n_extra_tokens = seqlen_k % BLOCK_N - PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + + # figure out masking pattern + n_front_skip_blocks, n_front_masked_blocks, n_full_blocks, n_back_masked_blocks, n_extra_tokens = compute_masking( + seqlen_k, seqlen_q, start_m, IS_CAUSAL, USE_SLIDING_WINDOW, + WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, BLOCK_M, BLOCK_N + ) + + # ============================================================ + # PROGRAM EARLY EXIT (All K Blocks Skipped) + # ============================================================ + total_visible_blocks = n_front_masked_blocks + n_full_blocks + n_back_masked_blocks + if total_visible_blocks == 0: + """ + No K blocks visible - write zeros and exit. + """ + # Write zeros to output + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_mask = (offs_m[:, None] < seqlen_q) + if PADDED_HEAD: + o_mask = o_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + tl.store(o_ptrs, tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty), mask=o_mask) + + # Write zeros to LSE + l_ptrs = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m + offs_m * stride_lse_m + tl.store(l_ptrs, tl.zeros([BLOCK_M], dtype=tl.float32), mask=offs_m < seqlen_q) + return + + # ============================================================ + # NORMAL PROCESSING (Some K Blocks Visible) + # ============================================================ + """ + This program has visible K blocks to process. + We'll use two calls to handle different block types efficiently. + """ + + # Initialize for processing # Compute pointers for all the tensors used in this kernel. q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk @@ -413,132 +649,206 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, else: dropout_mask_ptrs = None philox_ptrs = 0 + # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=ACCUMULATOR_TYPE) l_i = tl.full([BLOCK_M], 1.0, dtype=ACCUMULATOR_TYPE) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=ACCUMULATOR_TYPE) + # Q is loaded once at the beginning and shared by all N blocks. q_ptrs_mask = offs_m[:, None] < seqlen_q if PADDED_HEAD: q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) - # Load scale factors if IS_FP8. - if IS_FP8: - descale_q = tl.load(Descale_Q + off_z * stride_descale_q_z + off_h_q) - descale_k = tl.load(Descale_K + off_z * stride_descale_k_z + off_h_k) - descale_v = tl.load(Descale_V + off_z * stride_descale_v_z + off_h_k) - else: - descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 - # Here we compute how many full and masked blocks we have. - padded_block_k = n_extra_tokens != 0 - is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) - if IS_CAUSAL: - # There are always at least BLOCK_M // BLOCK_N masked blocks. - # Additionally there might be one more due to dissimilar seqlens. - masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) - else: - # Padding on Q does not need to be masked in the FA loop. - masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. - # In this case we might exceed n_blocks so pick the min. - masked_blocks = min(masked_blocks, n_blocks) - n_full_blocks = n_blocks - masked_blocks - block_min = 0 - block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its actual - # value because there is no masking. Similarly we do not need padding. + # ========== Process MASKED K Blocks in the front ========== + # NOTE: we use USE_SLIDING_WINDOW as guard because the compiler will crash other wise. front masking is only for sliding window so that is fine. + if n_front_masked_blocks > 0 and USE_SLIDING_WINDOW: + block_min = n_front_skip_blocks * BLOCK_N + block_max = (n_front_skip_blocks + n_front_masked_blocks) * BLOCK_N + + acc, l_i, m_i = _attn_fwd_mask( + acc, l_i, m_i, + q, k_ptrs, v_ptrs, bias_ptrs, + stride_kn, stride_vk, stride_bn, stride_sn, + start_m, seqlen_k, seqlen_q, + dropout_p, philox_seed, philox_ptrs, + sd_mask_ptrs, dropout_mask_ptrs, + offs_m, offs_n, offs_d, + block_min, # Start of front masked blocks + block_max, # End of front masked blocks + 0, # n_extra_tokens (0 for front blocks, only relevant for last block) + alibi_slope, + descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, + IS_CAUSAL, + BLOCK_M, BLOCK_DMODEL, BLOCK_N, + PRE_LOAD_V, + ENABLE_DROPOUT, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL, SM_SCALE, + USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, + RETURN_SCORES=RETURN_SCORES, + USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, + WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT, + ACCUMULATOR_TYPE=ACCUMULATOR_TYPE + ) + + # ========== Process FULL K Blocks (Fast Path) ========== if n_full_blocks > 0: - block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, - start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, - sd_mask_ptrs, dropout_mask_ptrs, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, block_max, 0, 0, 0, alibi_slope, - descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, - # IS_CAUSAL, .... - False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, False, ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE) - block_min = block_max - block_max = n_blocks * BLOCK_N - - tl.debug_barrier() - # Remaining blocks, if any, are full / not masked. - if (masked_blocks > 0): - if IS_CAUSAL: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) - else: - offs_n_causal = 0 - k_ptrs += n_full_blocks * BLOCK_N * stride_kn - v_ptrs += n_full_blocks * BLOCK_N * stride_vk - if USE_BIAS: - bias_ptrs += n_full_blocks * BLOCK_N * stride_bn - if RETURN_SCORES: - sd_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn - if ENABLE_DROPOUT: - dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn - philox_ptrs += n_full_blocks * BLOCK_N * stride_sn - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, - start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, - sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks, - n_extra_tokens, alibi_slope, descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, - IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE) - # epilogue + block_min = (n_front_skip_blocks + n_front_masked_blocks) * BLOCK_N + block_max = (n_front_skip_blocks + n_front_masked_blocks + n_full_blocks) * BLOCK_N + + acc, l_i, m_i = _attn_fwd_no_mask( + acc, l_i, m_i, + q, k_ptrs, v_ptrs, bias_ptrs, + stride_kn, stride_vk, stride_bn, stride_sn, + start_m, seqlen_k, seqlen_q, + dropout_p, philox_seed, philox_ptrs, + sd_mask_ptrs, dropout_mask_ptrs, + offs_m, offs_n, offs_d, + block_min, # Start of range: 0 + block_max, # End of range: n_full_blocks * BLOCK_N + alibi_slope, + descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, + BLOCK_M, BLOCK_DMODEL, BLOCK_N, + PRE_LOAD_V, + ENABLE_DROPOUT, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL, SM_SCALE, + USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, + RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE + ) + + # ========== Process MASKED K Blocks in the back ========== + if n_back_masked_blocks > 0: + block_min = (n_front_skip_blocks + n_front_masked_blocks + n_full_blocks) * BLOCK_N + block_max = (n_front_skip_blocks + n_front_masked_blocks + n_full_blocks + n_back_masked_blocks) * BLOCK_N + + acc, l_i, m_i = _attn_fwd_mask( + acc, l_i, m_i, + q, k_ptrs, v_ptrs, bias_ptrs, + stride_kn, stride_vk, stride_bn, stride_sn, + start_m, seqlen_k, seqlen_q, + dropout_p, philox_seed, philox_ptrs, + sd_mask_ptrs, dropout_mask_ptrs, + offs_m, offs_n, offs_d, + block_min, # Start of range: n_full_blocks * BLOCK_N + block_max, # End of range: n_visible_k_blocks * BLOCK_N + n_extra_tokens, # Padding tokens in last block + alibi_slope, + descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, + IS_CAUSAL, # Use actual causal flag + BLOCK_M, BLOCK_DMODEL, BLOCK_N, + PRE_LOAD_V, + ENABLE_DROPOUT, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL, SM_SCALE, + USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, + RETURN_SCORES=RETURN_SCORES, + USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, + WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT, + ACCUMULATOR_TYPE=ACCUMULATOR_TYPE + ) + + # ============================================================ + # EPILOGUE + # ============================================================ # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. - l_recip = 1 / l_i[:, None] + # Instead of directly computing 1/l_i which can be inf, + # we check for the invalid case first + if USE_SLIDING_WINDOW: + # For rows where m_i is still -inf, no keys were valid + # Set l_i to 1.0 to avoid division by zero (acc is already 0) + invalid_mask = m_i == float("-inf") + l_i_safe = tl.where(invalid_mask, 1.0, l_i) + l_recip = 1 / l_i_safe[:, None] + else: + # Original code path + l_recip = 1 / l_i[:, None] acc = acc * l_recip if ENABLE_DROPOUT: dropout_scale = 1 / (1 - dropout_p) acc = acc * dropout_scale - # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, - # then we have one block with a row of all NaNs which come from computing - # softmax over a row of all -infs (-inf - inf = NaN). We check for that here - # and store 0s where there are NaNs as these rows should've been zeroed out. - end_m_idx = (start_m + 1) * BLOCK_M - start_m_idx = start_m * BLOCK_M - causal_start_idx = seqlen_q - seqlen_k - if IS_CAUSAL: - if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) - mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] - z = 0.0 - acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - # write back LSE(Log Sum Exponents), the log of the normalization constant - l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m - l_ptrs = l_offset + offs_m * stride_lse_m + # compute log-sum-exp if USE_EXP2: RCP_LN2: tl.constexpr = 1.4426950408889634 LN2: tl.constexpr = 0.6931471824645996 # compute log-sum-exp in base 2 units mi_base2 = m_i * RCP_LN2 - softmax_lse = mi_base2 + tl.math.log2(l_i) + # For invalid rows, log(l_i) would be -inf, but we want LSE to be -inf + # So we handle this case explicitly + if USE_SLIDING_WINDOW: + log_l_i = tl.where(invalid_mask, 0.0, tl.math.log2(l_i)) + softmax_lse = mi_base2 + log_l_i + # Ensure invalid rows have LSE = -inf + softmax_lse = tl.where(invalid_mask, float("-inf"), softmax_lse) + else: + softmax_lse = mi_base2 + tl.math.log2(l_i) # convert back to natural units softmax_lse *= LN2 else: - softmax_lse = m_i + tl.math.log(l_i) + if USE_SLIDING_WINDOW: + log_l_i = tl.where(invalid_mask, 0.0, tl.math.log(l_i)) + softmax_lse = m_i + log_l_i + softmax_lse = tl.where(invalid_mask, float("-inf"), softmax_lse) + else: + softmax_lse = m_i + tl.math.log(l_i) + + # handle masking edge cases + if USE_SLIDING_WINDOW: + if IS_CAUSAL: + pass + else: + pass + else: + if IS_CAUSAL: + # When seqlen_q > seqlen_k, some rows are completely above the causal diagonal + # These rows have all -inf attention scores, resulting in NaN after softmax + # e.g. + # Q length: 6, K length: 4 + # Causal mask (X = can attend, . = cannot): + # K0 K1 K2 K3 + # Q0 . . . . <- All masked, would give NaN + # Q1 . . . . <- All masked, would give NaN + # Q2 X . . . <- First valid row + # Q3 X X . . + # Q4 X X X . + # Q5 X X X X + causal_start_idx = seqlen_q - seqlen_k + start_m_idx = start_m * BLOCK_M + + # Create mask for rows that need zeroing + row_indices = start_m_idx + tl.arange(0, BLOCK_M) + causal_mask = row_indices < causal_start_idx + + # Zero out both acc and LSE for these rows + if causal_start_idx > start_m_idx: + end_m_idx = (start_m + 1) * BLOCK_M + if causal_start_idx < end_m_idx: + # This block contains the boundary - need to mask acc + out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + out_ptrs_mask = row_indices[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + + # Zero out LSE for rows above diagonal + softmax_lse = tl.where(causal_mask, 0.0, softmax_lse) - if IS_CAUSAL: - # zero out nans caused by -infs when doing causal - lse_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx - softmax_lse = tl.where(lse_mask, 0.0, softmax_lse) + # write back LSE(Log Sum Exponents), the log of the normalization constant + l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m + l_ptrs = l_offset + offs_m * stride_lse_m # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. - # This is only true for the last M block. For others, overflow_size will be -ve + # This is only true for the last Q block. For others, overflow_size will be -ve + end_m_idx = (start_m + 1) * BLOCK_M overflow_size = end_m_idx - seqlen_q if overflow_size > 0: boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary - tl.store(l_ptrs, softmax_lse, mask=l_ptrs_mask) # the log of the normalization constant + tl.store(l_ptrs, softmax_lse, mask=l_ptrs_mask) else: - tl.store(l_ptrs, softmax_lse) # the log of the normalization constant + tl.store(l_ptrs, softmax_lse) # write back O o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om @@ -556,7 +866,6 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, else: tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) - def attention_prefill_forward_triton_impl( q: torch.Tensor, k: torch.Tensor, @@ -565,6 +874,8 @@ def attention_prefill_forward_triton_impl( sm_scale: float, alibi_slopes: Optional[torch.Tensor], causal: bool, + window_size_left: int, + window_size_right: int, bias: Optional[torch.Tensor], layout: Literal["bshd", "bhsd", "thd"], # varlen @@ -613,6 +924,7 @@ def attention_prefill_forward_triton_impl( # check flags IS_VARLEN = layout == "thd" + use_sliding_window = window_size_left != -1 or window_size_right!= -1 use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) is_inference = False if cache_seqlens is None else True if is_inference: @@ -629,6 +941,7 @@ def attention_prefill_forward_triton_impl( # shape total_q, nheads_q, head_size = q.shape _, nheads_k, _ = k.shape + assert cu_seqlens_q is not None batch = len(cu_seqlens_q) - 1 # softmax_lse is the log of the normalization constant / sum of expoential score(unnormalzied probablities) @@ -695,7 +1008,7 @@ def attention_prefill_forward_triton_impl( attn_fwd[grid](q, k, v, bias, cache_seqlens, cache_batch_idx, descale_q, descale_k, descale_v, descale_o, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, - sm_scale, softmax_lse, o, + sm_scale, softmax_lse, o, stride_qb, stride_qh, stride_qm, stride_qd, stride_kb, stride_kh, stride_kn, stride_kd, stride_vb, stride_vh, stride_vn, stride_vd, @@ -703,13 +1016,15 @@ def attention_prefill_forward_triton_impl( stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, stride_sz, stride_sh, stride_sm, stride_sn, - stride_lse_z, stride_lse_h, stride_lse_m, + stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, IS_VARLEN=IS_VARLEN, IS_INFERENCE=is_inference, + MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, + USE_SLIDING_WINDOW=use_sliding_window, WINDOW_SIZE_LEFT=window_size_left, WINDOW_SIZE_RIGHT=window_size_right, + IS_VARLEN=IS_VARLEN, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, - USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p - > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT, FLIP_GRID=FLIP_GRID) + USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, + IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT, FLIP_GRID=FLIP_GRID) - return softmax_lse, sd_mask if return_softmax else None + return softmax_lse, sd_mask if return_softmax else None \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/fwd_ref.py b/flash_attn/flash_attn_triton_amd/fwd_ref.py index 6af99798ae9..2265af12096 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/fwd_ref.py @@ -1,12 +1,16 @@ import torch import math -from typing import Literal, Optional +from typing import Literal, Optional, Union from .utils import compute_alibi_tensor_ref DEBUG = False DEBUG_CORE = False -def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2): +def attention_forward_core_ref_impl( + q, k, v, sm_scale, causal, window_size_left, window_size_right, + dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2, + cache_seqlens=None +): if DEBUG_CORE: print() print("attention_forward_core_ref_impl") @@ -15,15 +19,21 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox print("v:", v, v.shape) print("sm_scale:", sm_scale) print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) print("dropout_p:", dropout_p) print("philox_seed:", philox_seed) print("philox_offset:", philox_offset) print("use_exp2:", use_exp2) + print("cache_seqlens:", cache_seqlens) # cast to float32 q = q.to(torch.float32) k = k.to(torch.float32) v = v.to(torch.float32) + + # get seqlens + L_q, L_k = q.shape[1], k.shape[1] # Compute attention scores attention_scores = torch.matmul(q, k.transpose(-2, -1)) @@ -37,47 +47,146 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox # Apply ALiBi if slopes are provided if alibi_slopes is not None: - L_q, L_k = q.shape[1], k.shape[1] - if DEBUG_CORE: - print("alibi_slopes:", alibi_slopes, alibi_slopes.shape) - alibi_bias = compute_alibi_tensor_ref(alibi_slopes, L_q, L_k) - if DEBUG_CORE: - print("alibi_bias:", alibi_bias, alibi_bias.shape) - alibi_bias = alibi_bias.reshape(-1, L_q, L_k) - if DEBUG_CORE: - print("alibi_bias_flat:", alibi_bias, alibi_bias.shape) + if cache_seqlens is not None: + # DECODE MODE: Special ALiBi handling + # In decode mode, k has shape [nheads, max_cache_len, head_dim] + # but only cache_seqlens positions are valid + + # The test's attn_bias_from_alibi_slopes uses this formula: + # relative_pos = torch.abs(row_idx + sk - sq - col_idx) + # where sk = actual valid key length, sq = query length + + row_idx = torch.arange(L_q, device=q.device, dtype=torch.float32).unsqueeze(1) + col_idx = torch.arange(L_k, device=q.device, dtype=torch.float32).unsqueeze(0) + + # Compute relative positions + # cache_seqlens is the actual number of valid keys (sk in the test) + # L_q is the query sequence length (sq in the test) + relative_pos = torch.abs(row_idx + cache_seqlens - L_q - col_idx) + + # Apply slopes + if alibi_slopes.dim() == 1: + # Shape: [nheads] -> [nheads, 1, 1] + alibi_slopes_expanded = alibi_slopes.view(-1, 1, 1) + else: + # Already has batch dimension + alibi_slopes_expanded = alibi_slopes + + alibi_bias = -alibi_slopes_expanded * relative_pos + + if DEBUG_CORE: + print(f"Decode ALiBi: cache_seqlens={cache_seqlens}, L_q={L_q}, L_k={L_k}") + print(f"relative_pos shape: {relative_pos.shape}") + print(f"alibi_bias shape: {alibi_bias.shape}") + else: + if DEBUG_CORE: + print("alibi_slopes:", alibi_slopes, alibi_slopes.shape) + alibi_bias = compute_alibi_tensor_ref(alibi_slopes, L_q, L_k) + if DEBUG_CORE: + print("alibi_bias:", alibi_bias, alibi_bias.shape) + alibi_bias = alibi_bias.reshape(-1, L_q, L_k) + if DEBUG_CORE: + print("alibi_bias_flat:", alibi_bias, alibi_bias.shape) + attention_scaled_scores = attention_scaled_scores + alibi_bias if DEBUG_CORE: print("attention_scaled_scores after alibi:", attention_scaled_scores, attention_scaled_scores.shape) - - # Apply causal mask if necessary - if causal: - L_q, L_k = q.shape[1], k.shape[1] - row_idx = torch.arange(L_q, device=q.device).unsqueeze(1) - col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) - col_offset = L_q-L_k - causal_mask = row_idx >= (col_offset + col_idx) + # Apply masks + row_idx = torch.arange(L_q, device=q.device).unsqueeze(1) + col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) + + if cache_seqlens is not None: + # We're in decode mode with a KV cache + # k and v are full allocated size, but only cache_seqlens positions are valid + + # Create a mask for valid cache positions + cache_mask = col_idx < cache_seqlens + + # Use cache_seqlens for offset calculation to match test's construct_local_mask + # which uses key_padding_mask.sum() as the sequence length + col_offset = cache_seqlens - L_q + + if DEBUG_CORE: + print(f"Cache mode: valid_len={cache_seqlens}, L_k={L_k}") + print(f"Using col_offset={col_offset} based on valid cache length") + else: + # Calculate offset for when seqlen_q != seqlen_k + # This offset aligns query positions to key positions + # When L_q < L_k, offset is positive, meaning query i maps to key position (i + offset) + # This is consistent with construct_local_mask in the tests which uses (sk - sq) + col_offset = L_k - L_q + cache_mask = None + + mask_applied = False + if causal and (window_size_left, window_size_right) == (-1, -1): + # Pure causal: ensure query doesn't attend to future keys + # With offset, query i can attend to keys up to position (i + col_offset) + mask = row_idx >= (col_idx - col_offset) + mask_applied = True if DEBUG_CORE: - print("causal_mask:", causal_mask) - # set -inf to places the causal mask is false + print("causal_mask:", mask) + elif (window_size_left, window_size_right) != (-1, -1): + # Handle the case where window sizes exceed sequence length + if window_size_left >= L_k: + window_size_left = -1 # No left limit + if window_size_right >= L_k: + window_size_right = -1 # No right limit + + if causal: + # Causal + sliding window: ensure we don't attend to future + window_size_right = min(window_size_right, 0) if window_size_right != -1 else 0 + + # Create sliding window mask + # Each query at position i attends to keys in [i + offset - left, i + offset + right] + if window_size_left == -1 and window_size_right == -1: + # No window restriction + mask = torch.ones((L_q, L_k), dtype=torch.bool, device=q.device) + else: + mask = torch.ones((L_q, L_k), dtype=torch.bool, device=q.device) + if window_size_left != -1: + # Each query at position i attends to keys from position (i - left) accounting for offset + mask = mask & (col_idx >= (row_idx + col_offset - window_size_left)) + if window_size_right != -1: + # Each query at position i attends to keys up to position (i + right) accounting for offset + mask = mask & (col_idx <= (row_idx + col_offset + window_size_right)) + + # Apply causal constraint + if causal: + causal_mask = row_idx >= (col_idx - col_offset) + mask = mask & causal_mask + + mask_applied = True + if DEBUG_CORE: + print(f"sliding_window_mask (left={window_size_left}, right={window_size_right}):", mask) + + # Apply cache mask if needed + if cache_mask is not None: + if mask_applied: + mask = mask & cache_mask + else: + mask = cache_mask + mask_applied = True + + # Apply the mask if created + if mask_applied: attention_scaled_scores = attention_scaled_scores.masked_fill( - torch.logical_not(causal_mask.unsqueeze(0)), float('-inf') + torch.logical_not(mask.unsqueeze(0)), float('-inf') ) if DEBUG_CORE: - print("attention_scaled_scores after causal:", attention_scaled_scores, attention_scaled_scores.shape) + print("attention_scaled_scores after masking:", attention_scaled_scores, attention_scaled_scores.shape) # Compute max for numerical stability max_scores = torch.max(attention_scaled_scores, dim=-1, keepdim=True)[0] if DEBUG_CORE: print("max_scores:", max_scores, max_scores.shape) - if causal: + if mask_applied: # Replace -inf in max_scores with zeros to avoid NaN in subtraction max_scores = torch.where( torch.isinf(max_scores), torch.zeros_like(max_scores), max_scores ) - if DEBUG: - print("max_scores if causal:", max_scores, max_scores.shape) + if DEBUG_CORE: + print("max_scores after mask handling:", max_scores, max_scores.shape) # Shift scores attention_shifted_scaled_scores = attention_scaled_scores - max_scores @@ -98,7 +207,7 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox sum_exp_scores = torch.sum(exp_scores, dim=-1, keepdim=True) if DEBUG_CORE: print("sum_exp_scores:", sum_exp_scores, sum_exp_scores.shape) - if causal: + if mask_applied: # if sum of exp scores is 0.0 it means scores where -inf, we cannot compute softmax and softmax_lse. Setting to 1 deals with -inf case cleanly sum_exp_scores = torch.where( sum_exp_scores == 0, @@ -158,7 +267,7 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox return o, softmax_lse, sd_mask -def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2): +def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, window_size_left, window_size_right, layout, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2): """Compute reference output and softmax_lse using PyTorch's built-in function""" # Ensure the layout is 'bhsd' @@ -194,7 +303,7 @@ def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout # Call the core attention function o, softmax_lse, sd_mask = attention_forward_core_ref_impl( - q, k, v, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2 + q, k, v, sm_scale, causal, window_size_left, window_size_right, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2 ) if group_size != 1: @@ -224,6 +333,8 @@ def attention_varlen_forward_pytorch_ref_impl( v, sm_scale, causal, + window_size_left, + window_size_right, layout, cu_seqlens_q, cu_seqlens_k, @@ -302,7 +413,7 @@ def attention_varlen_forward_pytorch_ref_impl( alibi_slopes_i = None # Call the core attention function for this sequence - o_i, softmax_lse_i, sd_mask_i = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes_i, use_exp2) + o_i, softmax_lse_i, sd_mask_i = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, window_size_left, window_size_right, dropout_p, philox_seed, philox_offset, alibi_slopes_i, use_exp2) # Reshape outputs back to original dimensions if group_size != 1: @@ -328,8 +439,6 @@ def attention_varlen_forward_pytorch_ref_impl( return o, softmax_lse, sd_mask - - def attention_forward_pytorch_ref_impl( q: torch.Tensor, k: torch.Tensor, @@ -338,6 +447,8 @@ def attention_forward_pytorch_ref_impl( sm_scale: float, alibi_slopes: Optional[torch.Tensor], causal: bool, + window_size_left: int, + window_size_right: int, layout: Literal["bshd", "bhsd", "thd"], cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, @@ -356,6 +467,8 @@ def attention_forward_pytorch_ref_impl( v.clone(), sm_scale, causal, + window_size_left, + window_size_right, layout, cu_seqlens_q, cu_seqlens_k, @@ -374,6 +487,8 @@ def attention_forward_pytorch_ref_impl( v.clone(), sm_scale, causal, + window_size_left, + window_size_right, layout, dropout_p, philox_seed, @@ -385,3 +500,152 @@ def attention_forward_pytorch_ref_impl( out.copy_(o_ref.to(out.dtype)) return softmax_lse_ref, sd_mask_ref + +def attention_decode_forward_ref_impl( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_new: Optional[torch.Tensor], + v_new: Optional[torch.Tensor], + out: torch.Tensor, + sm_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + alibi_slopes: Optional[torch.Tensor], + layout: Literal["bshd"], + cache_seqlens: Optional[torch.Tensor], + cache_batch_idx: Optional[torch.Tensor], +): + """Compute reference output for decode attention using PyTorch's built-in functions""" + + # get batch size before any layout conversion + batch_size = q.shape[0] + + # handle cache_batch_idx + if cache_batch_idx is not None: + # remap batch indices for cache access + batch_indices = cache_batch_idx + else: + batch_indices = torch.arange(batch_size, device=q.device) + + # copy new keys and values into cache if provided (before any layout conversion) + if k_new is not None and v_new is not None: + _, seq_len_new, _, _ = k_new.shape # shape is [batch, seq_len, nheads, head_dim] for bshd layout + + for b in range(batch_size): + cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices + + # determine where to place new k/v in cache + if cache_seqlens is not None: + if torch.is_tensor(cache_seqlens): + start_pos = cache_seqlens[b].item() + else: + start_pos = cache_seqlens + else: + # if no cache_seqlens, assume we're filling from the beginning + start_pos = 0 + + end_pos = start_pos + seq_len_new + + # copy new keys and values into cache (both are in bshd layout) + k_cache[cache_idx, start_pos:end_pos, :, :] = k_new[b, :, :, :] + v_cache[cache_idx, start_pos:end_pos, :, :] = v_new[b, :, :, :] + + # ensure the layout is 'bhsd' + if layout == "bshd": + q = q.transpose(1, 2).contiguous() + k_cache = k_cache.transpose(1, 2).contiguous() + v_cache = v_cache.transpose(1, 2).contiguous() + elif layout != "bhsd": + raise ValueError(f"Unknown layout {layout}") + + # prepare tensors + batch_size_q, nheads_q, seq_len_q, head_dim = q.shape + batch_size_cache, nheads_k, max_cache_len, head_dim_k = k_cache.shape + _, nheads_v, _, head_dim_v = v_cache.shape + + # validate dimensions + assert head_dim == head_dim_k == head_dim_v, f"Head dimensions must match: {head_dim}, {head_dim_k}, {head_dim_v}" + + # handle MQA/GQA + group_size = nheads_q // nheads_k + if nheads_q % nheads_k != 0: + raise ValueError("nheads_q must be divisible by nheads_k") + + # handle cache_batch_idx + if cache_batch_idx is not None: + # remap batch indices for cache access + batch_indices = cache_batch_idx + else: + batch_indices = torch.arange(batch_size, device=q.device) + + # prepare outputs + o = torch.zeros_like(q) + softmax_lse = torch.zeros((batch_size, nheads_q, seq_len_q), dtype=torch.float32, device=q.device) + + # process each batch element + for b in range(batch_size): + cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices + + # determine valid cache length for this batch element + if cache_seqlens is not None: + if torch.is_tensor(cache_seqlens): + cache_len = cache_seqlens[b].item() + if k_new is not None: + _, seq_len_new, _, _ = k_new.shape + cache_len += seq_len_new + else: + cache_len = cache_seqlens + if k_new is not None: + _, seq_len_new, _, _ = k_new.shape + cache_len += seq_len_new + else: + cache_len = max_cache_len + + # CHANGE: Extract the full cache, not just valid portion + # This matches what the test does - it uses full k_cache_rep/v_cache_rep + k_b = k_cache[cache_idx, :, :, :] # [nheads_k, max_cache_len, head_dim] + v_b = v_cache[cache_idx, :, :, :] # [nheads_v, max_cache_len, head_dim] + q_b = q[b:b+1, :, :, :] # [1, nheads_q, seq_len_q, head_dim] + + # handle MQA/GQA by expanding k and v + if group_size != 1: + # expand k and v to match q's number of heads + k_b = k_b.unsqueeze(1).expand(-1, group_size, -1, -1) + k_b = k_b.reshape(nheads_q, max_cache_len, head_dim) + + v_b = v_b.unsqueeze(1).expand(-1, group_size, -1, -1) + v_b = v_b.reshape(nheads_q, max_cache_len, head_dim) + + # reshape for attention_forward_core_ref_impl + q_b = q_b.reshape(nheads_q, seq_len_q, head_dim) + + # handle alibi slopes for this batch + alibi_slopes_b = None + if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + alibi_slopes_b = alibi_slopes[b] + else: + alibi_slopes_b = alibi_slopes + + # call core attention function with cache information + o_b, softmax_lse_b, _ = attention_forward_core_ref_impl( + q_b, k_b, v_b, sm_scale, causal, window_size_left, window_size_right, + dropout_p=0.0, philox_seed=None, philox_offset=None, + alibi_slopes=alibi_slopes_b, use_exp2=True, + cache_seqlens=cache_len, # Pass valid cache length + ) + + # store outputs + o[b, :, :, :] = o_b.reshape(nheads_q, seq_len_q, head_dim) + softmax_lse[b, :, :] = softmax_lse_b.reshape(nheads_q, seq_len_q) + + # restore original layout if necessary + if layout == "bshd": + o = o.transpose(1, 2) + + # copy output to the provided tensor + out.copy_(o.to(out.dtype)) + + return softmax_lse \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index 9c10b7436c2..4e9bd006aee 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -5,7 +5,7 @@ from .bwd_prefill_fused_atomics import attention_prefill_backward_triton_fused_atomics_impl from .bwd_prefill_fused_no_atomics import attention_prefill_backward_triton_split_fused_no_atomics_impl from .fwd_decode import attention_decode_forward_triton_impl -from .fwd_ref import attention_forward_pytorch_ref_impl +from .fwd_ref import attention_forward_pytorch_ref_impl, attention_decode_forward_ref_impl from .bwd_ref import attention_backward_pytorch_ref_impl from .utils import DEBUG, USE_REF, MetaData, is_fp8 from einops import rearrange, repeat @@ -101,6 +101,8 @@ def fwd(q: torch.Tensor, metadata.sm_scale, metadata.alibi_slopes, metadata.causal, + window_size_left, + window_size_right, metadata.layout, metadata.cu_seqlens_q, metadata.cu_seqlens_k, @@ -123,6 +125,8 @@ def fwd(q: torch.Tensor, metadata.sm_scale, metadata.alibi_slopes, metadata.causal, + window_size_left, + window_size_right, None, metadata.layout, metadata.cu_seqlens_q, @@ -253,6 +257,8 @@ def bwd( softmax_scale, alibi_slopes, causal, + window_size_left, + window_size_right, "bshd", None, None, @@ -472,6 +478,8 @@ def varlen_fwd( metadata.sm_scale, metadata.alibi_slopes, metadata.causal, + window_size_left, + window_size_right, metadata.layout, metadata.cu_seqlens_q, metadata.cu_seqlens_k, @@ -494,6 +502,8 @@ def varlen_fwd( metadata.sm_scale, metadata.alibi_slopes, metadata.causal, + window_size_left, + window_size_right, None, metadata.layout, metadata.cu_seqlens_q, @@ -626,6 +636,8 @@ def varlen_bwd( softmax_scale, alibi_slopes, causal, + window_size_left, + window_size_right, "thd", cu_seqlens_q, cu_seqlens_k, @@ -801,8 +813,21 @@ def fwd_kvcache( metadata.layout = "bshd" metadata.max_seqlens_q = q.shape[1] metadata.max_seqlens_k = k_cache.shape[1] - metadata.cache_seqlens = cache_seqlens metadata.cache_batch_idx = cache_batch_idx + if isinstance(cache_seqlens, int): + metadata.cache_seqlens = torch.tensor(cache_seqlens, device=q.device) + else: + metadata.cache_seqlens = cache_seqlens + + # window_size can be a tensor sometimes + if isinstance(window_size_left, torch.Tensor): + metadata.window_size_left = int(window_size_left.item()) + else: + metadata.window_size_left = window_size_left + if isinstance(window_size_right, torch.Tensor): + metadata.window_size_right = int(window_size_right.item()) + else: + metadata.window_size_right = window_size_right k_new = k v_new = v @@ -829,7 +854,7 @@ def fwd_kvcache( # Rotary Embedding Implementation if apply_rotary: - if metadata.causal: # NOTE: when support is added. Add `or metadata.local` + if metadata.causal or (window_size_left != -1 or window_size_right !=-1): # NOTE: when support is added. Add `or metadata.local` q_ro = apply_rotary_emb( q, metadata.rotary_cos, @@ -860,8 +885,29 @@ def fwd_kvcache( q, k_new = q_ro.to(q.dtype), k_ro.to(q.dtype) # launch kernel - DECODE_KERNEL= True # os.environ.get('DECODE_KERNEL', '0').lower() in ('1', 'true', 'yes') - if DECODE_KERNEL: + if USE_REF: + if DEBUG: + print("Using reference implementation") + softmax_lse_ref = attention_decode_forward_ref_impl( + q, + k_cache, + v_cache, + k_new, + v_new, + out, + metadata.sm_scale, + metadata.causal, + metadata.window_size_left, + metadata.window_size_right, + metadata.alibi_slopes, + metadata.layout, + metadata.cache_seqlens, + metadata.cache_batch_idx, + ) + softmax_lse=softmax_lse_ref + else: + if DEBUG: + print("Using Triton implementation") softmax_lse_triton = attention_decode_forward_triton_impl( q, k_cache, @@ -871,38 +917,14 @@ def fwd_kvcache( out, metadata.sm_scale, metadata.causal, + metadata.window_size_left, + metadata.window_size_right, metadata.alibi_slopes, metadata.layout, metadata.cache_seqlens, metadata.cache_batch_idx, ) - else: - softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( - q, - k_cache, - v_cache, - out, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - None, - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.cache_seqlens, - metadata.cache_batch_idx, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - metadata.return_softmax, - USE_EXP2, - None, - None, - None, - None) - softmax_lse = softmax_lse_triton + softmax_lse = softmax_lse_triton if DEBUG: print("out:", out, out.shape) diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index f634103ca69..9cffda127a6 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -43,355 +43,6 @@ # ATOL_fp8, RTOL_fp8 = 2e-2, 2e-2 # fp8 EQUAL_NAN = True -@pytest.mark.parametrize( - "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", - [ - (1, 1, 1, 1, 1, 1), - (1, 1, 1, 2, 4, 16), - (1, 2, 2, 2, 4, 16), - (1, 4, 1, 2, 4, 16), - (1, 4, 2, 2, 4, 16), - (1, 1, 1, 4, 2, 16), - (1, 1, 1, 4, 4, 16), - (1, 2, 2, 4, 4, 16), - (2, 1, 1, 4, 4, 16), - (2, 2, 2, 4, 4, 16), - (1, 1, 1, 128, 64, 16), - (2, 2, 2, 2, 128, 1), - (2, 3, 3, 2, 128, 16), - (3, 2, 2, 256, 512, 16), - (3, 3, 3, 128, 128, 64), - (2, 4, 4, 1024, 1024, 64), - (4, 6, 6, 108, 256, 224), - (4, 8, 8, 2048, 2048, 128), - (4, 16, 16, 4096, 4096, 64), - (2, 4, 4, 8192, 8192, 32), - # fa configs - (4, 6, 1, 113, 203, 256), - (4, 6, 1, 128, 217, 256), - (4, 6, 2, 113, 211, 128), - (4, 6, 2, 108, 256, 128), - (4, 6, 1, 256, 512, 64), - (4, 6, 1, 512, 256, 64), - (4, 6, 2, 1024, 1024, 32), - (4, 6, 2, 1023, 1024, 32), - (4, 6, 6, 1024, 1023, 32), - (4, 6, 6, 2048, 2048, 32), - ], -) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('dropout_p', [0.0]) -@pytest.mark.parametrize('alibi_slopes', [None]) -@pytest.mark.parametrize('layout', ["bshd", "thd"]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.parametrize('use_exp2', [True, False]) # works when use_exp2 is false -@pytest.mark.parametrize('DEBUG_INPUT', [False]) # NOTE: debug input can overflow when the tensors are large. Just use to figure out issues -@pytest.mark.skip() -def test_op_prefill_fwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, alibi_slopes, layout, dtype, use_exp2, DEBUG_INPUT): - torch.manual_seed(42) - device = "cuda" - - q, k, v, do, metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, dtype, layout=layout, device=device) - - if DEBUG: - if HQ // HK != 1: - print("MQA/GQA") - else: - print("MHA") - - # update metadata - if causal: - metadata.need_causal(True) - - # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - metadata.need_dropout(dropout_p, True) - - - # call Triton's forward implementation directly - q_triton = q.clone() - k_triton = k.clone() - v_triton = v.clone() - o_triton = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) - softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( - q_triton, - k_triton, - v_triton, - o_triton, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - metadata.bias, - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.cache_seqlens, - metadata.cache_batch_idx, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - metadata.return_softmax, - use_exp2, - None, - None, - None, - None) - - # ref forward - q_ref = q.clone() - k_ref = k.clone() - v_ref = v.clone() - o_ref = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) - softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( - q_ref, - k_ref, - v_ref, - o_ref, - metadata.sm_scale, - metadata.alibi_slopes, - causal, - layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - use_exp2 - ) - - if DEBUG: - print() - print("Compare Triton Impl with refernce Pytorch Impl") - - # this can be set to true manually or when using dropout - if metadata.return_softmax: - if DEBUG: - print("sd_mask_triton:", sd_mask_triton, sd_mask_triton.shape) - print("sd_mask_ref:", sd_mask_ref, sd_mask_ref.shape) - torch.testing.assert_close(sd_mask_triton.to(sd_mask_ref.dtype), sd_mask_ref, atol=ATOL, rtol=RTOL) - - if DEBUG: - print("softmax_lse_triton:", softmax_lse_triton, softmax_lse_triton.shape) - print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape) - torch.testing.assert_close(softmax_lse_triton, softmax_lse_ref, atol=ATOL, rtol=RTOL) - - if DEBUG: - print("output_triton:", o_triton, o_triton.shape) - print("output_ref:", o_ref, o_ref.shape) - torch.testing.assert_close(o_triton, o_ref, atol=ATOL, rtol=RTOL) - -@pytest.mark.parametrize( - "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", [ - (1, 1, 1, 1, 1, 1), - (1, 1, 1, 4, 4, 4), - (2, 1, 1, 4, 4, 16), - (1, 2, 2, 4, 4, 16), - (1, 4, 1, 2, 4, 16), - (1, 8, 1, 2, 4, 16), - (1, 16, 1, 2, 4, 16), - (1, 32, 1, 2, 4, 16), - (1, 64, 1, 2, 4, 16), - (1, 4, 2, 2, 4, 16), - (2, 2, 2, 4, 4, 16), - (1, 1, 1, 4, 4, 16), - (2, 1, 1, 4, 4 , 16), - (4, 6, 6, 8, 8 , 16), - (1, 1, 1, 4, 4, 32), - (1, 1, 1, 16, 16, 16), - (1, 1, 1, 32, 32, 16), - (1, 1, 1, 64, 64, 16), - (1, 1, 1, 64, 64, 16), - (1, 1, 1, 64, 128, 16), - (1, 1, 1, 64, 64, 32), - (1, 1, 1, 64, 128, 32), - (1, 1, 1, 128, 128, 64), - (1, 1, 1, 128, 256, 45), - (1, 1, 1, 113, 203, 192), - (1, 1, 1, 256, 256, 64), - (1, 1, 1, 256, 512, 16), - (1, 1, 1, 512, 512, 64), - (1, 1, 1, 1024, 1024, 64), - # fa configs - (2, 2, 2, 128, 128, 65), - (2, 2, 2, 128, 128, 224), - (4, 6, 6, 108, 256, 224), - (1, 1, 1, 256, 512, 16), - # old tests that work - (4, 48, 6, 1024, 1024, 64), - (4, 48, 12, 2048, 1024, 64), - (4, 48, 24, 1024, 1024, 64), - (4, 48, 48, 1024, 1024, 64), - (4, 48, 48, 1024, 1024, 73), - (4, 48, 48, 2048, 2048, 64), - (1, 24, 24, 4096, 4096, 64), - (1, 16, 16, 1024, 1024, 64), - (1, 16, 16, 1024, 1024, 128), - # testcase new - # seqlen q == k - (1, 1, 1, 2, 2, 2), # small enough to debug - (1, 1, 1, 128, 128, 32), # only one block - (1, 1, 1, 127, 127, 32), # only one block but with masking - (1, 1, 1, 129, 129, 1), # two blocks with 2nd block small enough to debug - (1, 1, 1, 350, 350, 1), # two blocks with 2nd block small enough to debug - (1, 1, 1, 350, 350, 68), # generic masking on q, k and head - (4, 1, 1, 512, 512, 128), # batch > 1 - (4, 8, 2, 512, 512, 128), # GQA - (4, 8, 2, 512, 512, 68), # non-power-of-2 head_dim - (4, 8, 2, 500, 500, 68), # comprehensive case for seqlen q == k - # seqlen q > k - (1, 1, 1, 64, 32, 8), # seqlen_q > seqlen_k - (1, 1, 1, 192, 128, 32), # seqlen_q > seqlen_k - (4, 8, 2, 1024, 512, 68), # seqlen_q < seqlen_k - (1, 1, 1, 729, 516, 68), # seqlen_q > seqlen_k - (16, 16, 4, 2753, 1528, 68), # a comprehensive seqlen_q > seqlen_k - # seqlen q < k - (1, 1, 1, 32, 64, 8), # seqlen_q > seqlen_k - (1, 1, 1, 128, 192, 32), # seqlen_q < seqlen_k - (4, 8, 2, 512, 1024, 68), # seqlen_q < seqlen_k - (1, 1, 1, 200, 413, 1), # seqlen_q < seqlen_k - (1, 1, 1, 782, 1546, 1), # seqlen_q < seqlen_k - (16, 16, 4, 1528, 2753, 68), # a comprehensive seqlen_q < seqlen_k -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('dropout_p', [0.0]) -@pytest.mark.parametrize('alibi_slopes', [None]) -@pytest.mark.parametrize('layout', ["bshd", "thd"]) -@pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.parametrize('use_exp2', [False]) # FIXME: using exp2 causes issue when used with causal -@pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans on larger tensors -@pytest.mark.skip() -def test_op_prefill_bwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, alibi_slopes, layout, dtype, use_exp2, DEBUG_INPUT): - torch.manual_seed(20) - device="cuda" - - # gen inputs - q, k, v, do, metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, dtype, layout=layout, device=device) - - # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - metadata.need_dropout(dropout_p, True) - - # =============================================== Reference ============================================================== - # fwd - q_ref = q.clone() - k_ref = k.clone() - v_ref = v.clone() - output_ref = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) - softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( - q_ref, - k_ref, - v_ref, - output_ref, - metadata.sm_scale, - metadata.alibi_slopes, - causal, - layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - use_exp2 - ) - - # bwd - do_ref = do.clone() - dq_ref = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) - dk_ref = torch.zeros_like(k).contiguous() if DEBUG_INPUT else torch.empty_like(k) - dv_ref = torch.zeros_like(v).contiguous() if DEBUG_INPUT else torch.empty_like(v) - delta_ref = attention_backward_pytorch_ref_impl( - do_ref, - q_ref, - k_ref, - v_ref, - output_ref, - softmax_lse_ref, - dq_ref, - dk_ref, - dv_ref, - metadata.sm_scale, - metadata.alibi_slopes, - causal, - layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - use_exp2 - ) - - # =============================================== Triton ============================================================== - do_triton = do.clone() - q_triton = q.clone() - k_triton = k.clone() - v_triton = v.clone() - o_triton = output_ref.clone().contiguous() - softmax_lse_triton = softmax_lse_ref.clone().contiguous() - dq_triton = torch.zeros_like(q_triton, dtype=q.dtype) # NOTE: the kernel does inplace accumlation on dq so dq has to be zeros - dk_triton = torch.zeros_like(k_triton, dtype=k.dtype) if DEBUG_INPUT else torch.empty_like(k_triton, dtype=k.dtype) - dv_triton = torch.zeros_like(v_triton, dtype=v.dtype) if DEBUG_INPUT else torch.empty_like(v_triton, dtype=v.dtype) - delta_triton = attention_prefill_backward_triton_split_fused_no_atomics_impl( - do_triton, - q_triton, - k_triton, - v_triton, - o_triton, - softmax_lse_triton, - dq_triton, - dk_triton, - dv_triton, - metadata.sm_scale, - alibi_slopes, - causal, - layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - use_exp2, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - # =============================================== Check ============================================================== - if DEBUG: - print() - if DEBUG: - print("delta_triton:", delta_triton, delta_triton.shape) - print("delta_ref:", delta_ref, delta_ref.shape) - torch.testing.assert_close(delta_triton, delta_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) - - if DEBUG: - print("dv_triton:", dv_triton, dv_triton.shape) - print("dv_ref:", dv_ref, dv_ref.shape) - torch.testing.assert_close(dv_triton, dv_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) - - if DEBUG: - print("dk_triton:", dk_triton, dk_triton.shape) - print("dk_ref:", dk_ref, dk_ref.shape) - torch.testing.assert_close(dk_triton, dk_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) - - if DEBUG: - print("dq_triton:", dq_triton, dq_triton.shape) - print("dq_ref:", dq_ref, dq_ref.shape) - torch.testing.assert_close(dq_triton, dq_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) - def fp8_assert_close(tensor_a, tensor_b, atol=ATOL_fp8, rtol=RTOL_fp8, max_diff_percentage=0.5): """Assert tensors are close with tolerance for small percentage of elements""" # standard comparison @@ -523,7 +174,7 @@ def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, pac # test apis if packing == 'qkv': # generate inputs - qkv, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, packing=packing, device=device, DEBUG_INPUT=DEBUG_INPUT) + qkv, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, packing=packing, device=device) # ---------------------------------------------------------------- # --- FP8 --- @@ -631,7 +282,7 @@ def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, pac elif packing is None: # generate inputs - q, k, v, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT) + q, k, v, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, device=device) # ---------------------------------------------------------------- # --- FP8 --- @@ -951,6 +602,7 @@ def clear_compile_cache(): # (16, 48, 16, 256, 512, 64), # MQA test (HQ > HK) ], ) +@pytest.mark.skip() # this works or not depending on the torch and triton version compatibility. Keep the test but use it in docker containers where things are in sync def test_torch_compile(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD): print(f"\n\nTesting with BATCH={BATCH}, HQ={HQ}, HK={HK}, N_CTX_Q={N_CTX_Q}, N_CTX_K={N_CTX_K}, D_HEAD={D_HEAD}") diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 1795e0d1366..96d4f662567 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -6,7 +6,8 @@ import functools import triton import triton.language as tl -from typing import Literal, Optional, Union +import numpy as np +from typing import Literal, Optional # ------------------------------- # Gloabl Variables @@ -42,7 +43,7 @@ class MetaData(): num_contexts = 0 varlen: bool = False layout: Optional[Literal["bshd", "bhsd", "thd"]] = None - cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None + cache_seqlens: Optional[torch.Tensor] = None cache_batch_idx = None packing: Optional[bool] = None return_softmax: bool = False @@ -54,6 +55,8 @@ class MetaData(): rotary_cos: Optional[torch.Tensor] = None rotary_interleaved: bool = False rotary_conjunction: bool = False + window_size_left: int = -1 + window_size_right: int = -1 def __repr__(self) -> str: @@ -73,6 +76,8 @@ def __repr__(self) -> str: f" cache_batch_idx={self.cache_batch_idx},\n" f" dropout_p={self.dropout_p},\n" f" return_softmax={self.return_softmax}\n" + f" window_size_left={self.window_size_left},\n" + f" window_size_right={self.window_size_right},\n" f")") def __init__(self, sm_scale=1.0): @@ -166,7 +171,7 @@ def generate_varlen_tensor( equal_seqlens: bool = False, device: str = "cuda", dtype: torch.dtype = torch.float16, - DEBUG_INPUT: bool = False + mode: Literal["random", "ones", "incremental", "identity"] = "random" ): if DEBUG: print("total_seqlen", total_seqlen) @@ -200,8 +205,8 @@ def generate_varlen_tensor( cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)]).to(torch.int32).to(device=device) max_seqlen = torch.max(seqlens).to(torch.int32).item() - # create varlen tensor - if DEBUG_INPUT: + # create varlen tensor based on mode + if mode == "incremental": x = torch.zeros(total_seqlen, num_heads, head_size, dtype=dtype, device=device) for i in range(batch_size): start = cu_seqlens[i].item() @@ -213,8 +218,23 @@ def generate_varlen_tensor( .view(length, 1, 1) .expand(length, num_heads, head_size) ) - else: + elif mode == "identity": + x = torch.zeros(total_seqlen, num_heads, head_size, dtype=dtype, device=device) + # for each batch, create identity pattern within that batch's sequence + for i in range(batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i+1].item() + length = end - start + + # create identity pattern for positions within this batch + for pos in range(min(length, head_size)): + x[start + pos, :, pos] = 1.0 + elif mode == "random": x = torch.randn((total_seqlen, num_heads, head_size), dtype=dtype, device=device) + elif mode == "ones": + x = torch.ones((total_seqlen, num_heads, head_size), dtype=dtype, device=device) + else: + raise ValueError(f"Unkown mode {mode}") if is_fp8_dtype: # cast to fp8 @@ -225,19 +245,28 @@ def generate_varlen_tensor( x.requires_grad_() return x, cu_seqlens, max_seqlen -def generate_bshd_tensor(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", DEBUG_INPUT=False): +def generate_bshd_tensor(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", mode: Literal["random", "ones", "incremental", "identity"] = "random"): # save fp8 type is_fp8_dtype = is_dtype_fp8(dtype) if is_fp8_dtype: og_fp8_dtype = dtype dtype = torch.float32 - # gen tensor + # gen tensor based on mode tensor_shape = (BATCH, SEQ_LEN, NUM_HEADS, D_HEAD) - if DEBUG_INPUT: + if mode == "incremental": x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, SEQ_LEN, 1, 1).expand(*tensor_shape).contiguous() - else: + elif mode == "identity": + x = torch.zeros(tensor_shape, dtype=dtype, device=device) + # create identity pattern: position i has value 1 at dimension i + for i in range(min(SEQ_LEN, D_HEAD)): + x[:, i, :, i] = 1.0 + elif mode == "random": x = torch.randn(tensor_shape, dtype=dtype, device=device) + elif mode == "ones": + x = torch.ones(tensor_shape, dtype=dtype, device=device) + else: + raise ValueError(f"Unkown mode {mode}") if is_fp8_dtype: # cast to fp8 @@ -248,26 +277,31 @@ def generate_bshd_tensor(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dtype = x.requires_grad_() return x -def generate_bhsd_tensor(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", DEBUG_INPUT=False): +def generate_bhsd_tensor(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", mode: Literal["random", "ones", "incremental", "identity"] = "random"): # save fp8 type is_fp8_dtype = is_dtype_fp8(dtype) if is_fp8_dtype: og_fp8_dtype = dtype dtype = torch.float32 - # gen tensor + # gen tensor based on mode tensor_shape = (BATCH, NUM_HEADS, SEQ_LEN, D_HEAD) - if DEBUG_INPUT: + if mode == "incremental": x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, 1, SEQ_LEN, 1).expand(*tensor_shape).contiguous() - else: + elif mode == "identity": + x = torch.zeros(tensor_shape, dtype=dtype, device=device) + # create identity pattern: position i has value 1 at dimension i + for i in range(min(SEQ_LEN, D_HEAD)): + x[:, :, i, i] = 1.0 + elif mode == "random": x = torch.randn(tensor_shape, dtype=dtype, device=device) + elif mode == "ones": + x = torch.ones(tensor_shape, dtype=dtype, device=device) + else: + raise ValueError(f"Unkown mode {mode}") - if is_fp8_dtype: - # cast to fp8 - x, descale_x = cast_to_fp8(x, og_fp8_dtype, "bhsd") # FIXME: I don't the casting fn supports this atm - x.requires_grad_() - return x, descale_x + raise ValueError("fp8 not supported for bhsd yet") else: x.requires_grad_() return x @@ -513,8 +547,7 @@ def input_helper( dtype: torch.dtype, layout: Literal["bshd", "bhsd", "thd"], packing: Optional[Literal["kv", "qkv"]] = None, - device: Literal["cpu", "cuda"] = "cuda", - DEBUG_INPUT: bool = False, + device: Literal["cpu", "cuda"] = "cuda" ): torch.manual_seed(20) is_fp8_dtype = is_dtype_fp8(dtype) @@ -529,23 +562,23 @@ def input_helper( if packing is None: # gen tensors if is_fp8_dtype: - q, cu_seqlens_q, max_seqlen_q, descale_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - k, cu_seqlens_k, max_seqlen_k, descale_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - v, _, _, descale_v = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + q, cu_seqlens_q, max_seqlen_q, descale_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) + k, cu_seqlens_k, max_seqlen_k, descale_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) + v, _, _, descale_v = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) do, _, _, descale_do = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) else: - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - v, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) + k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) + v, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) + do, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) elif packing == "kv": # gen tensors with kv packing if is_fp8_dtype: raise ValueError("FP8 not supported for KV packing yet") else: - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - kv, cu_seqlens_k, max_seqlen_k = generate_varlen_kv_packed(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) - do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) + kv, cu_seqlens_k, max_seqlen_k = generate_varlen_kv_packed(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) + do, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) elif packing == "qkv": # qkv packing - requires same sequence length for q and k assert N_CTX_Q == N_CTX_K, "For QKV packing, Q and K must have same sequence length" @@ -554,17 +587,13 @@ def input_helper( if is_fp8_dtype: raise ValueError("FP8 not supported for QKV packing yet") else: - qkv, cu_seqlens_q, max_seqlen_q = generate_varlen_qkv_packed(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + qkv, cu_seqlens_q, max_seqlen_q = generate_varlen_qkv_packed(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) cu_seqlens_k = cu_seqlens_q max_seqlen_k = max_seqlen_q - # create dummy do for qkv case - do = torch.ones((TOTAL_SEQLENS_Q, HQ, D_HEAD), dtype=dtype, device=device) if DEBUG_INPUT else torch.randn((TOTAL_SEQLENS_Q, HQ, D_HEAD), dtype=dtype, device=device) + do, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) # setup metadata - if DEBUG_INPUT: - sm_scale = 1 - else: - sm_scale = D_HEAD**-0.5 + sm_scale = D_HEAD**-0.5 metadata = MetaData(sm_scale=sm_scale) metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) metadata.need_causal(CAUSAL) @@ -576,39 +605,38 @@ def input_helper( # gen tensors if layout == "bshd": if is_fp8_dtype: - q, descale_q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - k, descale_k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - v, descale_v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + q, descale_q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) + k, descale_k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device) + v, descale_v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device) do, descale_do = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) else: - q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) + k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device) + v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device) + do = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) elif layout == "bhsd": - if is_fp8_dtype: - q, descale_q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - k, descale_k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - v, descale_v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + q, descale_q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) + k, descale_k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device) + v, descale_v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device) do, descale_do = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) - else: - q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + else: + q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) + k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device) + v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device) + do = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) elif packing == "kv": # gen tensors with kv packing if is_fp8_dtype: raise ValueError("FP8 not supported for KV packing yet") else: if layout == "bshd": - q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - kv = generate_bshd_kv_packed(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) + kv = generate_bshd_kv_packed(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device) + do = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) elif layout == "bhsd": - q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - kv = generate_bhsd_kv_packed(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) + kv = generate_bhsd_kv_packed(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device) + do = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) elif packing == "qkv": # qkv packing - requires same sequence length for q and k assert N_CTX_Q == N_CTX_K, "For QKV packing, Q and K must have same sequence length" @@ -618,17 +646,14 @@ def input_helper( raise ValueError("FP8 not supported for QKV packing yet") else: if layout == "bshd": - qkv = generate_bshd_qkv_packed(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - do = torch.ones((BATCH, N_CTX_Q, HQ, D_HEAD), dtype=dtype, device=device) if DEBUG_INPUT else torch.randn((BATCH, N_CTX_Q, HQ, D_HEAD), dtype=dtype, device=device) + qkv = generate_bshd_qkv_packed(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) + do = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) elif layout == "bhsd": - qkv = generate_bhsd_qkv_packed(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - do = torch.ones((BATCH, HQ, N_CTX_Q, D_HEAD), dtype=dtype, device=device) if DEBUG_INPUT else torch.randn((BATCH, HQ, N_CTX_Q, D_HEAD), dtype=dtype, device=device) + qkv = generate_bhsd_qkv_packed(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) + do = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) # setup metadata - if DEBUG_INPUT: - sm_scale = 1 - else: - sm_scale = D_HEAD**-0.5 + sm_scale = D_HEAD**-0.5 metadata = MetaData(sm_scale=sm_scale) metadata.max_seqlens_q = N_CTX_Q metadata.max_seqlens_k = N_CTX_K @@ -957,6 +982,29 @@ def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k): def round_multiple(x, m): return (x + m - 1) // m * m +def save_tensor_to_csv(tensor, filename, decimal_places=2): + """ + save a 2d tensor to csv file + + args: + tensor: torch tensor of shape [rows, cols] + filename: output csv filename + decimal_places: number of decimal places (default: 2) + """ + # ensure tensor is 2d + if tensor.ndim != 2: + raise ValueError(f"tensor must be 2d, got shape {tensor.shape}") + + # ensure filename ends with .csv + if not filename.endswith('.csv'): + filename = filename + '.csv' + + # save to csv using numpy + np.savetxt(filename, + tensor.detach().cpu().numpy(), + delimiter=',', + fmt=f'%.{decimal_places}f') + # ------------------------------- # Dropouts # ------------------------------- @@ -1016,6 +1064,91 @@ def write_dropout_mask(x, tensor_name = "tensor"): else: writer.writerows(dropout_mask) +# ------------------------------- +# Autotune +# ------------------------------- +def get_fwd_prefill_cdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + # Fall-back config. + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] + + +def get_fwd_prefill_rdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + # Fall-back config. + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] + + +def get_fwd_prefill_autotune_configs(): + if AUTOTUNE: + if is_rdna(): + return get_fwd_prefill_rdna_autotune_configs() + elif is_cdna(): + return get_fwd_prefill_cdna_autotune_configs() + else: + raise ValueError("Unknown Device Type") + else: + arch = get_arch() + if arch == "gfx950": + default_config = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + elif arch == "gfx942" and False: # Disabled due shared mem oom in CI when using triton==3.3.0 when using top of tree everything seems fine. + default_config = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + else: + default_config = triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + + return [ + default_config + ], [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL", + "IS_VARLEN", + "HQ", + "HK", + ] + # ------------------------------- # Runtime info # ------------------------------- diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py index 6073cb1c35a..adebe088a79 100755 --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -16,7 +16,7 @@ from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import _get_block_size_n from flash_attn.layers.rotary import apply_rotary_emb -from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_hip, is_rdna +from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_hip, is_rdna, generate_bshd_tensor, save_tensor_to_csv MAX_HEADDIM_SM8x = 192 @@ -573,7 +573,7 @@ def get_dropout_fraction( # @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("alibi", [False, True]) # @pytest.mark.parametrize("alibi", [False]) -@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False]) @@ -712,6 +712,10 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) + if local == True: + print("Sliding Window not supported in backward yet") + return + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() @@ -722,7 +726,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ # @pytest.mark.parametrize("deterministic", [True]) @pytest.mark.parametrize("alibi", [False, True]) # @pytest.mark.parametrize("alibi", [True]) -@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [False]) @@ -860,6 +864,10 @@ def test_flash_attn_varlen_qkvpacked( if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) + if local == True: + print("Sliding Window not supported in backward yet") + return + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() @@ -874,7 +882,7 @@ def test_flash_attn_varlen_qkvpacked( # @pytest.mark.parametrize("deterministic", [True]) @pytest.mark.parametrize("alibi", [False, True]) # @pytest.mark.parametrize("alibi", [False]) -@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) @@ -1133,6 +1141,10 @@ def test_flash_attn_output( if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) + if local == True: + print("Sliding Window not supported in backward yet") + return + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() @@ -1149,7 +1161,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize("deterministic", [True]) @pytest.mark.parametrize("alibi", [False, True]) # @pytest.mark.parametrize("alibi", [True]) -@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [True]) @@ -1455,6 +1467,10 @@ def test_flash_attn_varlen_output( if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04) + if local == True: + print("Sliding Window not supported in backward yet") + return + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() @@ -1463,7 +1479,7 @@ def test_flash_attn_varlen_output( @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -1569,6 +1585,10 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 + if local == True: + print("Sliding Window not supported in backward yet") + return + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 @@ -1576,7 +1596,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -1737,6 +1757,10 @@ def test_flash_attn_varlen_causal( # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 + if local == True: + print("Sliding Window not supported in backward yet") + return + if test_backward: assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 @@ -1749,7 +1773,7 @@ def test_flash_attn_varlen_causal( # @pytest.mark.parametrize("deterministic", [True]) @pytest.mark.parametrize("alibi", [False, True]) # @pytest.mark.parametrize("alibi", [True]) -@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) @@ -1867,6 +1891,10 @@ def test_flash_attn_splitkv( # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 + if local == True: + print("Sliding Window not supported in backward yet") + return + mult = 2 if not alibi else 8 assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4 assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4 @@ -1883,7 +1911,7 @@ def test_flash_attn_splitkv( # @pytest.mark.parametrize("new_kv", [False]) @pytest.mark.parametrize("alibi", [False, True]) # @pytest.mark.parametrize("alibi", [False]) -@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False]) @@ -2404,7 +2432,7 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) @@ -2452,6 +2480,10 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True) + if local == True: + print("Sliding Window not supported in backward yet") + return + g = torch.randn_like(out) dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True) for _ in range(50): @@ -2463,7 +2495,7 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) @@ -2540,6 +2572,10 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus deterministic=True, ) + if local == True: + print("Sliding Window not supported in backward yet") + return + g = torch.randn_like(out) dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) for _ in range(50): From b3c5204cc847ae6ee94f0b5ea4dcb481df0dc7a5 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Mon, 14 Jul 2025 14:51:11 -0400 Subject: [PATCH 06/27] Fix Device Segfault (#152) * Compress segfault work fix backward segfault rework offset ignore .profile ignore .analysis save * assert the kernel launch device and tensor devices are the same * fix failing asserts * add asserts to fwd --- .gitignore | 1 + .../bwd_prefill_fused_no_atomics.py | 234 +++++++++++------- .../flash_attn_triton_amd/fwd_prefill.py | 175 ++++++++----- 3 files changed, 263 insertions(+), 147 deletions(-) diff --git a/.gitignore b/.gitignore index 9559348b043..9dddf555d43 100644 --- a/.gitignore +++ b/.gitignore @@ -53,3 +53,4 @@ training/data # ck modules csrc/composable_kernel csrc/cutlass +.analysis \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py index 8bdcfd10d6a..5b2f8858d11 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py @@ -99,8 +99,7 @@ def get_autotune_configs(): # Here is the I/O shape: # Out: (batch, nhead_q, max_seqlens_q, headDim) # DO: (batch, nhead_q, max_seqlens_q, headDim) -# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at -# fwd_prefill.py line 607 +# Delta: (batch, nheads_q, max_seqlens_q) @triton.autotune( configs=preprocess_autotune_configs, key=preprocess_autotune_keys, @@ -108,9 +107,11 @@ def get_autotune_configs(): ) @triton.jit def _bwd_preprocess( - O, DO, # noqa: E741 + O, + DO, # noqa: E741 Delta, stride_ob, stride_oh, stride_om, stride_od, + stride_dob, stride_doh, stride_dom, stride_dod, stride_delta_b, stride_delta_h, stride_delta_m, stride_descale_do_z, cu_seqlens_q, max_seqlen_q, @@ -125,8 +126,6 @@ def _bwd_preprocess( bid = tl.program_id(1) hid = tl.program_id(2) # Handle varlen - q_start = 0 - seqlen_q = max_seqlen_q if IS_VARLEN: q_start = tl.load(cu_seqlens_q + bid) q_end = tl.load(cu_seqlens_q + bid + 1) @@ -138,32 +137,41 @@ def _bwd_preprocess( # Compute offsets offs_m = pid_m * PRE_BLOCK + tl.arange(0, PRE_BLOCK) offs_d = tl.arange(0, HEAD_DIM) - # Offset O/DO by batch, head and q_start - O += bid * stride_ob + hid * stride_oh + q_start * stride_om # noqa: E741 - DO += bid * stride_ob + hid * stride_oh + q_start * stride_om + # pointer offsets for O & DO + off_o = ( bid * stride_ob + + hid * stride_oh + + q_start * stride_om + + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_od) # noqa: E741 + off_do = (bid * stride_dob + + hid * stride_doh + + q_start * stride_dom + + offs_m[:, None] * stride_dom + + offs_d[None, :] * stride_dod) + # create masks mask_m = offs_m < seqlen_q mask_md = mask_m[:, None] PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) if PADDED_HEAD: mask_md &= offs_d[None, :] < ACTUAL_HEAD_DIM - # compute pointers - offs_do = offs_m[:, None] * stride_om + offs_d[None, :] * stride_od - out_ptrs = O + offs_do - do_ptrs = DO + offs_do # load - o = tl.load(out_ptrs, mask=mask_md, other=0.0) - do = tl.load(do_ptrs, mask=mask_md, other=0.0) + o = tl.load(O + off_o, mask=mask_md, other=0.0) + do = tl.load(DO + off_do, mask=mask_md, other=0.0) # compute and write-back to delta if IS_FP8: - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hid) + off_descale_do = bid * stride_descale_do_z + hid + descale_do = tl.load(Descale_do + off_descale_do) # NOTE: do is in the fp8 range and o is not in fp8 delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) else: delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) - delta_offset = Delta + bid * stride_delta_b + hid * stride_delta_h + q_start * stride_delta_m - tl.store(delta_offset + offs_m * stride_delta_m, delta, mask=mask_m) + off_delta = (bid * stride_delta_b + + hid * stride_delta_h + + q_start * stride_delta_m + + offs_m * stride_delta_m) + tl.store(Delta + off_delta , delta, mask=mask_m) # The main inner-loop logic for computing dK and dV. @@ -1063,8 +1071,9 @@ def is_contiguous(x, name): print(f"{name} is not contiguous") return x.contiguous() - -OLD_LSE = os.environ.get('OLD_LSE', '0').lower() in ('1', 'true', 'yes') +OLD_LSE: bool = False +DEBUG_TRITON: bool = False +DEBUG_TRITON_DETAIL: bool = False def attention_prefill_backward_triton_split_fused_no_atomics_impl( do: torch.Tensor, @@ -1098,25 +1107,118 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( descale_dk: Optional[torch.Tensor], descale_dv: Optional[torch.Tensor], ): - # debug - DEBUG_TRITON: bool = False - DEBUG_TRITON_DETAIL: bool = False - - # do = is_contiguous(do, "do") - # q = is_contiguous(q, "q") - # k = is_contiguous(k, "k") - # v = is_contiguous(v, "v") - # o = is_contiguous(o, "o") - # softmax_lse = is_contiguous(softmax_lse, "softmax_lse") - # dq = is_contiguous(dq, "dq") - # dk = is_contiguous(dk, "dk") - # dv = is_contiguous(dv, "dv") + # get params, strides and shape + IS_VARLEN = layout == "thd" + use_dropout = (dropout_p > 0.0) + + # common assertions + assert 0.0 <= dropout_p <= 1.0, f"dropout_p must be between 0 and 1, got {dropout_p}" + assert q.device == k.device == v.device == o.device == do.device == softmax_lse.device, \ + f"All tensors must be on the same device. Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}, do={do.device}, softmax_lse={softmax_lse.device}" + assert q.dtype == k.dtype == v.dtype == do.dtype, "q, k, v, do must have the same dtype" + current_device = torch.cuda.current_device() + assert q.is_cuda and q.device.index == current_device, f"Device mismatch: Kernel will launch on cuda:{current_device}, but tensors are on {q.device}" + # get shapes and strides + if IS_VARLEN: + # shape + total_seqlen_q, nheads_q, head_size_q = q.shape + total_seqlen_k, nheads_k, head_size_k = k.shape + total_seqlen_v, nheads_v, head_size_v = v.shape + nheads_lse, total_seqlen_lse = softmax_lse.shape + + # assert shapes + assert total_seqlen_lse == total_seqlen_q, f"softmax_lse seqlen {total_seqlen_lse} != q seqlen {total_seqlen_q}" + assert cu_seqlens_q is not None, "cu_seqlens_q must be provided for varlen layout" + assert cu_seqlens_k is not None, "cu_seqlens_k must be provided for varlen layout" + assert max_seqlen_q is not None, "max_seqlen_q must be provided for varlen layout" + assert max_seqlen_k is not None, "max_seqlen_k must be provided for varlen layout" + + # assert head dimensions + assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + assert nheads_lse == nheads_q, f"softmax_lse heads {nheads_lse} != q heads {nheads_q}" + + # assert output shapes + assert o.shape == (total_seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_q)}" + assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" + assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}" + assert dv.shape == v.shape, f"dv shape {dv.shape} != v shape {v.shape}" + + # assert cu_seqlens + assert cu_seqlens_q.dtype == torch.int32, f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}" + assert cu_seqlens_k.dtype == torch.int32, f"cu_seqlens_k must be int32, got {cu_seqlens_k.dtype}" + assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0" + assert cu_seqlens_k[0] == 0, "cu_seqlens_k must start with 0" + assert cu_seqlens_q[-1] == total_seqlen_q, f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} != total_seqlen_q {total_seqlen_q}" + assert cu_seqlens_k[-1] == total_seqlen_k, f"cu_seqlens_k[-1] {cu_seqlens_k[-1]} != total_seqlen_k {total_seqlen_k}" + + # set vars + batch = len(cu_seqlens_q) - 1 + head_size = head_size_q + + # strides + stride_qb, stride_qm, stride_qh, stride_qd = 0, q.stride(0), q.stride(1), q.stride(2) + stride_kb, stride_kn, stride_kh, stride_kd = 0, k.stride(0), k.stride(1), k.stride(2) + stride_vb, stride_vn, stride_vh, stride_vd = 0, v.stride(0), v.stride(1), v.stride(2) + stride_ob, stride_om, stride_oh, stride_od = 0, o.stride(0), o.stride(1), o.stride(2) + stride_dqb, stride_dqm, stride_dqh, stride_dqd = 0, dq.stride(0), dq.stride(1), dq.stride(2) + stride_dkb, stride_dkn, stride_dkh, stride_dkd = 0, dk.stride(0), dk.stride(1), dk.stride(2) + stride_dvb, stride_dvn, stride_dvh, stride_dvd = 0, dv.stride(0), dv.stride(1), dv.stride(2) + stride_dob, stride_dom, stride_doh, stride_dod = 0, do.stride(0), do.stride(1), do.stride(2) + stride_lse_b, stride_lse_h, stride_lse_m = (0, softmax_lse.stride(0), softmax_lse.stride(1)) + else: + # shapes + batch_q, seqlen_q, nheads_q, head_size_q = q.shape + batch_k, seqlen_k, nheads_k, head_size_k = k.shape + batch_v, seqlen_v, nheads_v, head_size_v = v.shape + batch_lse, nheads_lse, seqlen_lse = softmax_lse.shape + + # assert batch dimensions + assert batch_q == batch_k == batch_v, f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" + + # assert head dimensions + assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + + # assert sequence lengths + assert seqlen_k == seqlen_v, f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" + + # assert output shapes + assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected" + assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" + assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}" + assert dv.shape == v.shape, f"dv shape {dv.shape} != v shape {v.shape}" + + # assert softmax_lse shape + assert softmax_lse.shape == (batch_q, nheads_q, seqlen_q), f"softmax_lse shape {softmax_lse.shape} != expected" + + # set vars + batch = batch_q + head_size = head_size_q + max_seqlen_q = seqlen_q + max_seqlen_k = seqlen_k + + # strides + stride_qb, stride_qm, stride_qh, stride_qd = q.stride() + stride_kb, stride_kn, stride_kh, stride_kd = k.stride() + stride_vb, stride_vn, stride_vh, stride_vd = v.stride() + stride_ob, stride_om, stride_oh, stride_od = o.stride() + stride_dqb, stride_dqm, stride_dqh, stride_dqd = dq.stride() + stride_dkb, stride_dkn, stride_dkh, stride_dkd = dk.stride() + stride_dvb, stride_dvn, stride_dvh, stride_dvd = dv.stride() + stride_dob, stride_dom, stride_doh, stride_dod = do.stride() + stride_lse_b, stride_lse_h, stride_lse_m = softmax_lse.stride() + + # fp8 setup - moved after all assertions IS_FP8 = is_fp8(q) if IS_FP8: FP8_MAX = torch.finfo(q.dtype).max - # assert that the main inputs are fp8 - assert is_fp8(do) and is_fp8(q) and is_fp8(k) and is_fp8(v), f"Non fp8 type found: do.dtype={do.dtype}, q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}. All tensors must be fp8." + # we already asserted that do, q, k, v all have the same dtype, so no need to check each one if is_fp8(o): FP8_OUTPUT = True assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o." @@ -1136,45 +1238,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( FP8_OUTPUT = False stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = stride_descale_do_z = None - - # get params, strides and shape - IS_VARLEN = layout == "thd" - use_dropout = (dropout_p > 0.0) - - # get shapes and strides - if IS_VARLEN: - # shape - _, nheads_q, head_size = q.shape - _, nheads_k, _ = k.shape - batch = len(cu_seqlens_q) - 1 - max_seqlen_q_final = max_seqlen_q - max_seqlen_k_final = max_seqlen_k - - # strides - stride_qb, stride_qh, stride_qm, stride_qd = 0, q.stride(1), q.stride(0), q.stride(2) - stride_kb, stride_kh, stride_kn, stride_kd = 0, k.stride(1), k.stride(0), k.stride(2) - stride_vb, stride_vh, stride_vn, stride_vd = 0, v.stride(1), v.stride(0), v.stride(2) - stride_ob, stride_oh, stride_om, stride_od = 0, o.stride(1), o.stride(0), o.stride(2) - stride_dqb, stride_dqh, stride_dqm, stride_dqd = 0, dq.stride(1), dq.stride(0), dq.stride(2) - stride_dkb, stride_dkh, stride_dkn, stride_dkd = 0, dk.stride(1), dk.stride(0), dk.stride(2) - stride_dvb, stride_dvh, stride_dvn, stride_dvd = 0, dv.stride(1), dv.stride(0), dv.stride(2) - stride_dob, stride_doh, stride_dom, stride_dod = 0, do.stride(1), do.stride(0), do.stride(2) - stride_lse_b, stride_lse_h, stride_lse_m = (0, softmax_lse.stride(0), softmax_lse.stride(1)) - else: - # shapes - batch, max_seqlen_q_final, nheads_q, head_size = q.shape - _, max_seqlen_k_final, nheads_k, _ = k.shape - - # strides - stride_qb, stride_qh, stride_qm, stride_qd = q.stride(0), q.stride(2), q.stride(1), q.stride(3) - stride_kb, stride_kh, stride_kn, stride_kd = k.stride(0), k.stride(2), k.stride(1), k.stride(3) - stride_vb, stride_vh, stride_vn, stride_vd = v.stride(0), v.stride(2), v.stride(1), v.stride(3) - stride_ob, stride_oh, stride_om, stride_od = o.stride(0), o.stride(2), o.stride(1), o.stride(3) - stride_dqb, stride_dqh, stride_dqm, stride_dqd = dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3) - stride_dkb, stride_dkh, stride_dkn, stride_dkd = dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3) - stride_dvb, stride_dvh, stride_dvn, stride_dvd = dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3) - stride_dob, stride_doh, stride_dom, stride_dod = do.stride(0), do.stride(2), do.stride(1), do.stride(3) - stride_lse_b, stride_lse_h, stride_lse_m = softmax_lse.stride() + # alibi setup use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) # get closest power of 2 over or equal to 32. @@ -1193,28 +1257,28 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( else: if IS_VARLEN: # interface expects the varlen sequence dims to rounded like this. Not sure why. - batch_size = cu_seqlens_q.numel() - 1 total_q, num_heads, _ = q.shape - total_q_rounded = total_q + 128 * batch_size + total_q_rounded = total_q + 128 * batch delta_padded = torch.zeros((nheads_q, total_q_rounded), device=q.device, dtype=torch.float32) delta = delta_padded[:, :total_q] stride_delta_b, stride_delta_h, stride_delta_m = 0, delta.stride(0), delta.stride(1) else: # the interface expects the sequence dimension to be rounded to 128 - max_seqlen_q_rounded = round_multiple(max_seqlen_q_final, 128) + max_seqlen_q_rounded = round_multiple(max_seqlen_q, 128) delta_padded = torch.zeros((batch, nheads_q, max_seqlen_q_rounded), - device=softmax_lse.device, dtype=torch.float32) - delta = delta_padded[:, :, :max_seqlen_q_final] + device=q.device, dtype=torch.float32) + delta = delta_padded[:, :, :max_seqlen_q] stride_delta_b, stride_delta_h, stride_delta_m = delta.stride() - pre_grid = lambda META: (triton.cdiv(max_seqlen_q_final, META['PRE_BLOCK']), batch, nheads_q) + pre_grid = lambda META: (triton.cdiv(max_seqlen_q, META['PRE_BLOCK']), batch, nheads_q) _bwd_preprocess[pre_grid]( o, do, delta, stride_ob, stride_oh, stride_om, stride_od, + stride_dob, stride_doh, stride_dom, stride_dod, stride_delta_b, stride_delta_h, stride_delta_m, stride_descale_do_z, - cu_seqlens_q, max_seqlen_q_final, + cu_seqlens_q, max_seqlen_q, descale_do, HEAD_DIM=HEAD_DIM, ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, @@ -1232,7 +1296,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( (0, 0 , 0 , 0) if use_dropout: dropout_mask = torch.zeros( - (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), + (batch, nheads_q, max_seqlen_q, max_seqlen_k), device=q.device, dtype=torch.float32 ) @@ -1241,7 +1305,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( if not IS_VARLEN: dropout_mask = create_dropout_mask( dropout_p, - (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), + (batch, nheads_q, max_seqlen_q, max_seqlen_k), seed = philox_seed ) else: @@ -1252,7 +1316,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ dropout_mask.stride() - seqlen = max(max_seqlen_q_final, max_seqlen_k_final) + seqlen = max(max_seqlen_q, max_seqlen_k) grid = lambda META: (nheads_k, (seqlen + META['BLOCK_N1'] - 1) // META['BLOCK_N1'], batch, ) if causal: if DEBUG_TRITON: print(f"bwd_kernel: grid = {grid}" ) # noqa: E701 @@ -1273,7 +1337,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( stride_az, stride_ah, nheads_q, nheads_k, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q_final, max_seqlen_k_final, + max_seqlen_q, max_seqlen_k, dropout_mask, dropout_p, philox_seed, philox_offset, alibi_slopes, descale_q, descale_k, descale_v, descale_do, @@ -1307,7 +1371,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( stride_az, stride_ah, nheads_q, nheads_k, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q_final, max_seqlen_k_final, + max_seqlen_q, max_seqlen_k, dropout_mask, dropout_p, philox_seed, philox_offset, alibi_slopes, descale_q, descale_k, descale_v, descale_do, diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 59fe8bfaf4e..b646b486ce1 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -526,19 +526,14 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr, FLIP_GRID: tl.constexpr): + IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr): # set params ACCUMULATOR_TYPE = tl.float32 # compute offsets - if FLIP_GRID: - off_z = tl.program_id(0) - off_h_q = tl.program_id(1) - start_m = tl.program_id(2) - else: - start_m = tl.program_id(0) - off_h_q = tl.program_id(1) - off_z = tl.program_id(2) + off_z = tl.program_id(0) + off_h_q = tl.program_id(1) + start_m = tl.program_id(2) # If MQA / GQA, set the K and V head offsets appropriately. GROUP_SIZE: tl.constexpr = HQ // HK if GROUP_SIZE != 1: @@ -899,19 +894,116 @@ def attention_prefill_forward_triton_impl( descale_v: Optional[torch.Tensor], descale_o: Optional[torch.Tensor], ): + # get params, strides and shape + IS_VARLEN = layout == "thd" + + # common assertions + assert 0.0 <= dropout_p <= 1.0, f"dropout_p must be between 0 and 1, got {dropout_p}" + assert q.device == k.device == v.device == o.device, \ + f"All tensors must be on the same device. Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}" + assert q.dtype == k.dtype == v.dtype, "q, k, v must have the same dtype" + current_device = torch.cuda.current_device() + assert q.is_cuda and q.device.index == current_device, f"Device mismatch: Kernel will launch on cuda:{current_device}, but tensors are on {q.device}" + + # get shapes and strides + if IS_VARLEN: + # shape + total_seqlen_q, nheads_q, head_size_q = q.shape + total_seqlen_k, nheads_k, head_size_k = k.shape + total_seqlen_v, nheads_v, head_size_v = v.shape + + # assert shapes + assert cu_seqlens_q is not None, "cu_seqlens_q must be provided for varlen layout" + assert cu_seqlens_k is not None, "cu_seqlens_k must be provided for varlen layout" + assert max_seqlens_q is not None and max_seqlens_q > 0, "max_seqlens_q must be provided and positive for varlen layout" + assert max_seqlens_k is not None and max_seqlens_k > 0, "max_seqlens_k must be provided and positive for varlen layout" + + # assert head dimensions + assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + + # assert output shapes + assert o.shape == (total_seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_q)}" + + # assert cu_seqlens + assert cu_seqlens_q.dtype == torch.int32, f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}" + assert cu_seqlens_k.dtype == torch.int32, f"cu_seqlens_k must be int32, got {cu_seqlens_k.dtype}" + assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0" + assert cu_seqlens_k[0] == 0, "cu_seqlens_k must start with 0" + assert cu_seqlens_q[-1] == total_seqlen_q, f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} != total_seqlen_q {total_seqlen_q}" + assert cu_seqlens_k[-1] == total_seqlen_k, f"cu_seqlens_k[-1] {cu_seqlens_k[-1]} != total_seqlen_k {total_seqlen_k}" + + # set vars + batch = len(cu_seqlens_q) - 1 + head_size = head_size_q + + # softmax_lse shape + softmax_lse = torch.zeros((nheads_q, total_seqlen_q), device=q.device, dtype=torch.float32) + + # strides + stride_qb, stride_qh, stride_qm, stride_qd = 0, q.stride(1), q.stride(0), q.stride(2) + stride_kb, stride_kh, stride_kn, stride_kd = 0, k.stride(1), k.stride(0), k.stride(2) + stride_vb, stride_vh, stride_vn, stride_vd = 0, v.stride(1), v.stride(0), v.stride(2) + stride_ob, stride_oh, stride_om, stride_od = 0, o.stride(1), o.stride(0), o.stride(2) + stride_lse_z, stride_lse_h, stride_lse_m = 0, softmax_lse.stride(0), softmax_lse.stride(1) + else: + # shapes + batch_q, seqlen_q, nheads_q, head_size_q = q.shape + batch_k, seqlen_k, nheads_k, head_size_k = k.shape + batch_v, seqlen_v, nheads_v, head_size_v = v.shape + + # assert batch dimensions + assert batch_q == batch_k == batch_v, f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" + + # assert head dimensions + assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + + # assert sequence lengths + assert seqlen_k == seqlen_v, f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" + + # assert output shapes + assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(batch_q, seqlen_q, nheads_q, head_size_q)}" + + # set vars + batch = batch_q + head_size = head_size_q + max_seqlens_q = seqlen_q + max_seqlens_k = seqlen_k + + # softmax_lse shape + softmax_lse = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32) + + # strides + stride_qb, stride_qh, stride_qm, stride_qd = q.stride(0), q.stride(2), q.stride(1), q.stride(3) + stride_kb, stride_kh, stride_kn, stride_kd = k.stride(0), k.stride(2), k.stride(1), k.stride(3) + stride_vb, stride_vh, stride_vn, stride_vd = v.stride(0), v.stride(2), v.stride(1), v.stride(3) + stride_ob, stride_oh, stride_om, stride_od = o.stride(0), o.stride(2), o.stride(1), o.stride(3) + stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() + + # fp8 setup and assertions IS_FP8 = is_fp8(q) if IS_FP8: - FP8_MAX: tl.constexpr = torch.finfo(q.dtype).max + # we already asserted that q, k, v all have the same dtype, so no need to check each one - assert is_fp8(q) and is_fp8(k) and is_fp8(v), f"Non fp8 type found: q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}. All tensors must be fp8." + FP8_MAX = torch.finfo(q.dtype).max + # Check descale tensors + assert descale_q is not None, "descale_q must be provided when using fp8" + assert descale_k is not None, "descale_k must be provided when using fp8" + assert descale_v is not None, "descale_v must be provided when using fp8" + if is_fp8(o): FP8_OUTPUT = True - assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor for the output." + assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o." else: FP8_OUTPUT = False + # o should be fp32 or fp16/bf16 + assert o.dtype in [torch.float16, torch.bfloat16, torch.float32], \ + f"Output tensor o must be fp16, bf16, or fp32 when using fp8, got {o.dtype}" - # Get strides for the kernel stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None @@ -921,65 +1013,23 @@ def attention_prefill_forward_triton_impl( FP8_OUTPUT = False descale_q = descale_k = descale_v = descale_o = None stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = None - - # check flags - IS_VARLEN = layout == "thd" + + # check output dtype matches input dtype when not using fp8 + assert o.dtype == q.dtype, f"Output dtype {o.dtype} must match input dtype {q.dtype} when not using fp8" + + # check features use_sliding_window = window_size_left != -1 or window_size_right!= -1 use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) - is_inference = False if cache_seqlens is None else True - if is_inference: - assert layout == "bshd", f"{layout} layout is not supported with inference. Use bshd layout" - if DEBUG: - print(f"is_inference:", is_inference) - # NOTE: a large bias tensor leads to overflow during pointer arithmetic if (bias is not None): assert (bias.numel() < 2**31) - # get shape and strides - if IS_VARLEN: # thd layout - # shape - total_q, nheads_q, head_size = q.shape - _, nheads_k, _ = k.shape - assert cu_seqlens_q is not None - batch = len(cu_seqlens_q) - 1 - - # softmax_lse is the log of the normalization constant / sum of expoential score(unnormalzied probablities) - softmax_lse = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) - - # strides - stride_qb, stride_qh, stride_qm, stride_qd = 0, q.stride(1), q.stride(0), q.stride(2) - stride_kb, stride_kh, stride_kn, stride_kd = 0, k.stride(1), k.stride(0), k.stride(2) - stride_vb, stride_vh, stride_vn, stride_vd = 0, v.stride(1), v.stride(0), v.stride(2) - stride_ob, stride_oh, stride_om, stride_od = 0, o.stride(1), o.stride(0), o.stride(2) - stride_lse_z, stride_lse_h, stride_lse_m = 0, softmax_lse.stride(0), softmax_lse.stride(1) - else: # bshd layout - # shape - batch, seqlen_q, nheads_q, head_size = q.shape - _, _, nheads_k, _ = k.shape - - # softmax_lse is the log of the normalization constant / sum of expoential score(unnormalzied probablities) - softmax_lse = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32) - - # strides - stride_qb, stride_qh, stride_qm, stride_qd = q.stride(0), q.stride(2), q.stride(1), q.stride(3) - stride_kb, stride_kh, stride_kn, stride_kd = k.stride(0), k.stride(2), k.stride(1), k.stride(3) - stride_vb, stride_vh, stride_vn, stride_vd = v.stride(0), v.stride(2), v.stride(1), v.stride(3) - stride_ob, stride_oh, stride_om, stride_od = o.stride(0), o.stride(2), o.stride(1), o.stride(3) - stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() - # Get closest power of 2 over or equal to 32. padded_d_model = 1 << (head_size - 1).bit_length() # Smallest head_dim supported is 16. If smaller, the tile in the # kernel is padded - there is no padding in memory for any dims. padded_d_model = max(padded_d_model, 16) - FLIP_GRID = True - if FLIP_GRID: - grid = lambda META: (batch, nheads_q, triton.cdiv(max_seqlens_q, META['BLOCK_M'])) - else: - grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch) - # sd_mask is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing @@ -999,13 +1049,14 @@ def attention_prefill_forward_triton_impl( dropout_mask = None stride_sz, stride_sh, stride_sm, stride_sn = (0, 0, 0, 0) - if bias is not None: stride_bz, stride_bh, stride_bm, stride_bn = (bias.stride(0), bias.stride(1),bias.stride(2), bias.stride(3)) else: stride_bz, stride_bh, stride_bm, stride_bn = (0, 0, 0, 0) + # launch kernel + grid = lambda META: (batch, nheads_q, triton.cdiv(max_seqlens_q, META['BLOCK_M'])) attn_fwd[grid](q, k, v, bias, cache_seqlens, cache_batch_idx, descale_q, descale_k, descale_v, descale_o, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, sm_scale, softmax_lse, o, @@ -1025,6 +1076,6 @@ def attention_prefill_forward_triton_impl( IS_VARLEN=IS_VARLEN, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, - IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT, FLIP_GRID=FLIP_GRID) + IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT) return softmax_lse, sd_mask if return_softmax else None \ No newline at end of file From 855e86a941816c3a1a3831db4def4fc24f4a9951 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 17 Jul 2025 21:00:36 +0000 Subject: [PATCH 07/27] Fix SDMASK bug --- .../flash_attn_triton_amd/fwd_prefill.py | 19 ++++++++----------- .../flash_attn_triton_amd/interface_fa.py | 4 ---- flash_attn/flash_attn_triton_amd/test.py | 4 ---- 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index b646b486ce1..71b2b40458a 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -514,7 +514,7 @@ def compute_masking(seqlen_k, seqlen_q, start_m, use_cuda_graph=True, ) @triton.jit -def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, +def attn_fwd(Q, K, V, bias, Descale_Q, Descale_K, Descale_V, Descale_O, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, @@ -525,7 +525,7 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, NEEDS_SDMASK : tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr): # set params ACCUMULATOR_TYPE = tl.float32 @@ -630,7 +630,7 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, else: alibi_slope = None - if RETURN_SCORES: + if NEEDS_SDMASK: sd_mask_offset = sd_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm sd_mask_ptrs = sd_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn else: @@ -758,7 +758,7 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, l_i_safe = tl.where(invalid_mask, 1.0, l_i) l_recip = 1 / l_i_safe[:, None] else: - # Original code path + invalid_mask = None l_recip = 1 / l_i[:, None] acc = acc * l_recip if ENABLE_DROPOUT: @@ -878,9 +878,6 @@ def attention_prefill_forward_triton_impl( cu_seqlens_k: Optional[torch.Tensor], max_seqlens_q: int, max_seqlens_k: int, - # inference - cache_seqlens: Optional[Union[(int, torch.Tensor)]], - cache_batch_idx: Optional[torch.Tensor], # dropout dropout_p: float, philox_seed: Optional[int], @@ -1034,8 +1031,8 @@ def attention_prefill_forward_triton_impl( # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing # only. This return holds no useful output aside from debugging. - use_dropout = (dropout_p > 0.0) - if use_dropout or return_softmax: + NEEDS_SDMASK = (dropout_p > 0.0) or return_softmax + if NEEDS_SDMASK: sd_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, dtype=torch.float32) if DROPOUT_USE_PYTORCH: @@ -1057,7 +1054,7 @@ def attention_prefill_forward_triton_impl( # launch kernel grid = lambda META: (batch, nheads_q, triton.cdiv(max_seqlens_q, META['BLOCK_M'])) - attn_fwd[grid](q, k, v, bias, cache_seqlens, cache_batch_idx, + attn_fwd[grid](q, k, v, bias, descale_q, descale_k, descale_v, descale_o, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, sm_scale, softmax_lse, o, stride_qb, stride_qh, stride_qm, stride_qd, @@ -1075,7 +1072,7 @@ def attention_prefill_forward_triton_impl( USE_SLIDING_WINDOW=use_sliding_window, WINDOW_SIZE_LEFT=window_size_left, WINDOW_SIZE_RIGHT=window_size_right, IS_VARLEN=IS_VARLEN, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, - USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, + USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, NEEDS_SDMASK=NEEDS_SDMASK, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT) return softmax_lse, sd_mask if return_softmax else None \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index 4e9bd006aee..a6b83f64365 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -133,8 +133,6 @@ def fwd(q: torch.Tensor, metadata.cu_seqlens_k, metadata.max_seqlens_q, metadata.max_seqlens_k, - metadata.cache_seqlens, - metadata.cache_batch_idx, metadata.dropout_p, metadata.philox_seed, metadata.philox_offset, @@ -510,8 +508,6 @@ def varlen_fwd( metadata.cu_seqlens_k, metadata.max_seqlens_q, metadata.max_seqlens_k, - metadata.cache_seqlens, - metadata.cache_batch_idx, metadata.dropout_p, metadata.philox_seed, metadata.philox_offset, diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index 9cffda127a6..d57bd0fc745 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -22,10 +22,6 @@ ) from .utils import generate_bshd_kv_packed, generate_bshd_qkv_packed, generate_bshd_tensor, generate_varlen_kv_packed, generate_varlen_qkv_packed, input_helper, arch_supports_fp8, generate_varlen_tensor -from .fwd_ref import attention_forward_pytorch_ref_impl -from .fwd_prefill import attention_prefill_forward_triton_impl -from .bwd_prefill_fused_no_atomics import attention_prefill_backward_triton_split_fused_no_atomics_impl -from .bwd_ref import attention_backward_pytorch_ref_impl DEBUG = False From 33fa8bd63a52c56be72172e22c19360dffe50a38 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Fri, 18 Jul 2025 18:56:42 +0000 Subject: [PATCH 08/27] Log triton, torch and fa version --- flash_attn/flash_attn_triton_amd/test.py | 78 +++++++++++++----------- 1 file changed, 42 insertions(+), 36 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index d57bd0fc745..ed528292cc5 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -7,27 +7,16 @@ import logging import numpy as np from pathlib import Path -from flash_attn import ( - flash_attn_func, - flash_attn_fp8_func, - flash_attn_kvpacked_func, - flash_attn_qkvpacked_func, - flash_attn_qkvpacked_fp8_func, - flash_attn_varlen_func, - flash_attn_varlen_fp8_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_qkvpacked_func, - flash_attn_varlen_qkvpacked_fp8_func, - flash_attn_with_kvcache -) +import triton +import flash_attn from .utils import generate_bshd_kv_packed, generate_bshd_qkv_packed, generate_bshd_tensor, generate_varlen_kv_packed, generate_varlen_qkv_packed, input_helper, arch_supports_fp8, generate_varlen_tensor -DEBUG = False +# debugging -# set print options -# torch.set_printoptions(linewidth=5e5, edgeitems=10, sci_mode=False) -# np.set_printoptions(linewidth=5000, threshold=1e4, suppress=True, precision=4) +logging.basicConfig(level=logging.INFO, format='%(message)s', force=True) +logger = logging.getLogger(__name__) +DEBUG = False # defailt fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. See table https://pytorch.org/docs/stable/testing.html ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe to lose. @@ -179,7 +168,7 @@ def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, pac do_fp8= do.clone() if is_varlen: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_varlen_qkvpacked_fp8_func( + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( qkv_fp8, metadata.cu_seqlens_q, metadata.max_seqlens_q, @@ -192,7 +181,7 @@ def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, pac return_attn_probs=True, ) else: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_qkvpacked_fp8_func( + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_qkvpacked_fp8_func( qkv_fp8, dropout_p, causal=causal, @@ -211,7 +200,7 @@ def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, pac do_ref= do.clone() if is_varlen: - out_ref, lse_ref, S_dmask_ref = flash_attn_varlen_qkvpacked_func( + out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_varlen_qkvpacked_func( qkv_ref, metadata.cu_seqlens_q, metadata.max_seqlens_q, @@ -224,7 +213,7 @@ def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, pac return_attn_probs=True, ) else: - out_ref, lse_ref, S_dmask_ref = flash_attn_qkvpacked_func( + out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_qkvpacked_func( qkv_ref, dropout_p, causal=causal, @@ -292,7 +281,7 @@ def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, pac do_fp8= do.clone() if is_varlen: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_varlen_fp8_func( + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_varlen_fp8_func( q_fp8, k_fp8, v_fp8, @@ -309,7 +298,7 @@ def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, pac return_attn_probs=True, ) else: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_fp8_func( + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_fp8_func( q_fp8, k_fp8, v_fp8, @@ -335,7 +324,7 @@ def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, pac do_ref = do.clone() if is_varlen: - out_ref, lse_ref, S_dmask_ref = flash_attn_varlen_func( + out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_varlen_func( q_ref, k_ref, v_ref, @@ -352,7 +341,7 @@ def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, pac return_attn_probs=True, ) else: - out_ref, lse_ref, S_dmask_ref = flash_attn_func( + out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_func( q_ref, k_ref, v_ref, @@ -463,7 +452,7 @@ def test_ir(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, if packing == None: # fp8 forward pass if is_varlen: - out, lse, S_dmask = flash_attn_varlen_fp8_func( + out, lse, S_dmask = flash_attn.flash_attn_varlen_fp8_func( q, k, v, @@ -480,7 +469,7 @@ def test_ir(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, return_attn_probs=True, ) else: - out, lse, S_dmask = flash_attn_fp8_func( + out, lse, S_dmask = flash_attn.flash_attn_fp8_func( q, k, v, @@ -506,7 +495,7 @@ def test_ir(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, # fp8 forward pass for qkv-packed input if is_varlen: - out, lse, S_dmask = flash_attn_varlen_qkvpacked_fp8_func( + out, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( qkv, metadata.cu_seqlens_q, metadata.max_seqlens_q, @@ -519,7 +508,7 @@ def test_ir(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, return_attn_probs=True, ) else: - out, lse, S_dmask = flash_attn_qkvpacked_fp8_func( + out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_fp8_func( qkv, dropout_p, causal=causal, @@ -611,7 +600,7 @@ def test_torch_compile(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD): k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD) v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD) - flash_attn_func_compiled = torch.compile(flash_attn_func) + flash_attn_func_compiled = torch.compile(flash_attn.flash_attn_func) o = flash_attn_func_compiled(q, k, v, causal=True) print(f"Output shape: {o.shape}, dtype: {o.dtype}") o.sum().backward() @@ -630,7 +619,7 @@ def test_torch_compile(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD): k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(BATCH * N_CTX_K, HK, D_HEAD, BATCH) v, _, _ = generate_varlen_tensor(BATCH * N_CTX_K, HK, D_HEAD, BATCH) - flash_attn_varlen_func_compiled = torch.compile(flash_attn_varlen_func) + flash_attn_varlen_func_compiled = torch.compile(flash_attn.flash_attn_varlen_func) o = flash_attn_varlen_func_compiled( q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=True @@ -650,7 +639,7 @@ def test_torch_compile(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD): qkv = generate_bshd_qkv_packed(BATCH, N_CTX_Q, HQ, D_HEAD) - flash_attn_qkvpacked_func_compiled = torch.compile(flash_attn_qkvpacked_func) + flash_attn_qkvpacked_func_compiled = torch.compile(flash_attn.flash_attn_qkvpacked_func) o = flash_attn_qkvpacked_func_compiled(qkv, causal=True) print(f"Output shape: {o.shape}, dtype: {o.dtype}") o.sum().backward() @@ -668,7 +657,7 @@ def test_torch_compile(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD): total_q = BATCH * N_CTX_Q qkv, cu_seqlens, max_seqlen = generate_varlen_qkv_packed(total_q, HQ, D_HEAD, BATCH) - flash_attn_varlen_qkvpacked_func_compiled = torch.compile(flash_attn_varlen_qkvpacked_func) + flash_attn_varlen_qkvpacked_func_compiled = torch.compile(flash_attn.flash_attn_varlen_qkvpacked_func) o = flash_attn_varlen_qkvpacked_func_compiled( qkv, cu_seqlens, max_seqlen, causal=True ) @@ -688,7 +677,7 @@ def test_torch_compile(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD): q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD) kv = generate_bshd_kv_packed(BATCH, N_CTX_K, HK, D_HEAD) - flash_attn_kvpacked_func_compiled = torch.compile(flash_attn_kvpacked_func) + flash_attn_kvpacked_func_compiled = torch.compile(flash_attn.flash_attn_kvpacked_func) o = flash_attn_kvpacked_func_compiled(q, kv, causal=True) print(f"Output shape: {o.shape}, dtype: {o.dtype}") o.sum().backward() @@ -706,7 +695,7 @@ def test_torch_compile(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD): q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(BATCH * N_CTX_Q, HQ, D_HEAD, BATCH) kv, cu_seqlens_k, max_seqlen_k = generate_varlen_kv_packed(BATCH * N_CTX_K, HK, D_HEAD, BATCH) - flash_attn_varlen_kvpacked_func_compiled = torch.compile(flash_attn_varlen_kvpacked_func) + flash_attn_varlen_kvpacked_func_compiled = torch.compile(flash_attn.flash_attn_varlen_kvpacked_func) o = flash_attn_varlen_kvpacked_func_compiled( q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=True @@ -743,7 +732,7 @@ def test_torch_compile(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD): v_new = generate_bshd_tensor(BATCH, NEW_SEQLEN, HK, D_HEAD, dtype=torch.float16) # Note: flash_attn_with_kvcache doesn't support backward pass - flash_attn_with_kvcache_compiled = torch.compile(flash_attn_with_kvcache) + flash_attn_with_kvcache_compiled = torch.compile(flash_attn.flash_attn_with_kvcache) # Test with new k,v (append to cache and do attention) with torch.no_grad(): @@ -768,3 +757,20 @@ def test_torch_compile(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD): # final cleanup torch.cuda.empty_cache() clear_compile_cache() + +# log env +if os.environ.get('PYTEST_XDIST_WORKER') in (None, 'gw0'): + logger.info("\n" + "="*80) + logger.info("ENVIRONMENT INFORMATION") + logger.info("="*80) + + # triton + logger.info(f" Triton Version: {triton.__version__}") + + # torch + logger.info(f" PyTorch Version: {torch.__version__}") + + # flash attention + logger.info(f" Flash Attention Version: {flash_attn.__version__}") + + logger.info("="*80 + "\n") From 87b7d06bd7960a218f8447fe07674d12c542ae8f Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Sat, 19 Jul 2025 00:29:39 +0000 Subject: [PATCH 09/27] Fix fp8 import issues --- flash_attn/flash_attn_triton_amd/test.py | 17 +++++++++-------- flash_attn/flash_attn_triton_amd/train.py | 7 ++++--- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index ed528292cc5..15fa8f2f07a 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -9,6 +9,7 @@ from pathlib import Path import triton import flash_attn +import flash_attn.flash_attn_triton_amd.fp8 from .utils import generate_bshd_kv_packed, generate_bshd_qkv_packed, generate_bshd_tensor, generate_varlen_kv_packed, generate_varlen_qkv_packed, input_helper, arch_supports_fp8, generate_varlen_tensor @@ -168,7 +169,7 @@ def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, pac do_fp8= do.clone() if is_varlen: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_qkvpacked_fp8_func( qkv_fp8, metadata.cu_seqlens_q, metadata.max_seqlens_q, @@ -181,7 +182,7 @@ def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, pac return_attn_probs=True, ) else: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_qkvpacked_fp8_func( + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_qkvpacked_fp8_func( qkv_fp8, dropout_p, causal=causal, @@ -281,7 +282,7 @@ def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, pac do_fp8= do.clone() if is_varlen: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_varlen_fp8_func( + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_fp8_func( q_fp8, k_fp8, v_fp8, @@ -298,7 +299,7 @@ def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, pac return_attn_probs=True, ) else: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_fp8_func( + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_fp8_func( q_fp8, k_fp8, v_fp8, @@ -452,7 +453,7 @@ def test_ir(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, if packing == None: # fp8 forward pass if is_varlen: - out, lse, S_dmask = flash_attn.flash_attn_varlen_fp8_func( + out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_fp8_func( q, k, v, @@ -469,7 +470,7 @@ def test_ir(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, return_attn_probs=True, ) else: - out, lse, S_dmask = flash_attn.flash_attn_fp8_func( + out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_fp8_func( q, k, v, @@ -495,7 +496,7 @@ def test_ir(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, # fp8 forward pass for qkv-packed input if is_varlen: - out, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( + out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_qkvpacked_fp8_func( qkv, metadata.cu_seqlens_q, metadata.max_seqlens_q, @@ -508,7 +509,7 @@ def test_ir(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, return_attn_probs=True, ) else: - out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_fp8_func( + out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_qkvpacked_fp8_func( qkv, dropout_p, causal=causal, diff --git a/flash_attn/flash_attn_triton_amd/train.py b/flash_attn/flash_attn_triton_amd/train.py index fc5f5d0b1bf..8f39f627bfb 100644 --- a/flash_attn/flash_attn_triton_amd/train.py +++ b/flash_attn/flash_attn_triton_amd/train.py @@ -7,7 +7,8 @@ from tqdm import tqdm import matplotlib.pyplot as plt from datasets import load_dataset -from flash_attn import flash_attn_qkvpacked_func, flash_attn_qkvpacked_fp8_func, flash_attn_varlen_qkvpacked_func, flash_attn_varlen_qkvpacked_fp8_func +import flash_attn +import flash_attn.flash_attn_triton_amd.fp8 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"using device: {device}") @@ -40,13 +41,13 @@ def forward(self, x): # use the appropriate flash attention function if self.use_fp8: - context = flash_attn_qkvpacked_fp8_func( + context = flash_attn.flash_attn_triton_amd.fp8.flash_attn_qkvpacked_fp8_func( qkv_packed, dropout_p=self.dropout_p, causal=self.causal ) else: - context = flash_attn_qkvpacked_func( + context = flash_attn.flash_attn_qkvpacked_func( qkv_packed, dropout_p=self.dropout_p, causal=self.causal From 265a78a08162734cbbdae6b1b1401f3076db1839 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Wed, 23 Jul 2025 11:20:37 -0400 Subject: [PATCH 10/27] fix docs (#154) --- flash_attn/flash_attn_triton_amd/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/README.md b/flash_attn/flash_attn_triton_amd/README.md index f3a5db67fc5..87213c1883c 100644 --- a/flash_attn/flash_attn_triton_amd/README.md +++ b/flash_attn/flash_attn_triton_amd/README.md @@ -81,10 +81,10 @@ docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE ``` ###### FP8 -In our fork We have created the following api functions that use fp8 to compute their values. These functions are `flash_attn_fp8_func`, `flash_attn_varlen_fp8_func`, `flash_attn_qkvpacked_fp8_func` and `flash_attn_varlen_qkvpacked_fp8_func`. To use these functions just call them with like the other api functions, the casting will be handled internally. For example +In our fork We have created the following api functions that use fp8 internally to compute their values. These functions are `flash_attn_fp8_func`, `flash_attn_varlen_fp8_func`, `flash_attn_qkvpacked_fp8_func` and `flash_attn_varlen_qkvpacked_fp8_func`. Here is a usage example ``` -from flash_attn import flash_attn_qkvpacked_fp8_func +from flash_attn.flash_attn_triton_amd.fp8 import flash_attn_qkvpacked_fp8_func # forward pass out, lse, S_dmask = flash_attn_qkvpacked_fp8_func( From 175ad1f31ccbdeabf061a4cffda8d85c37923e9d Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Fri, 8 Aug 2025 04:37:57 -0400 Subject: [PATCH 11/27] Sliding Window block classification logic (#155) * add aiter code * remove aiter stuff * sliding window non causal masking works * causal and sliding window block masking * extract common * clean up typo * helper for swa * ignore .amd * fix last block bug --- .gitignore | 2 +- .../flash_attn_triton_amd/fwd_prefill.py | 169 +++++++++++++++--- 2 files changed, 141 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index 9dddf555d43..96d9af07e82 100644 --- a/.gitignore +++ b/.gitignore @@ -53,4 +53,4 @@ training/data # ck modules csrc/composable_kernel csrc/cutlass -.analysis \ No newline at end of file +.amd \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 71b2b40458a..3a2bd56fda4 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -382,26 +382,97 @@ def _attn_fwd_mask(acc, l_i, m_i, @triton.jit -def compute_masking(seqlen_k, seqlen_q, start_m, - IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, - WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): - """ - Classify K blocks for attention computation with sliding window support. +def compute_window_bounds(q_start, q_end, diag, seqlen_k, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + IS_CAUSAL: tl.constexpr): + """Calculate the window boundaries for a query block.""" + # Left boundary + if WINDOW_SIZE_LEFT < 0: + left_min = 0 + left_max = 0 + else: + left_min = tl.maximum(0, q_start + diag - WINDOW_SIZE_LEFT) + left_max = tl.maximum(0, q_end + diag - WINDOW_SIZE_LEFT) - Returns: - - n_front_skip_blocks: Blocks completely before the window - - n_front_masked_blocks: Blocks partially overlapping window front - - n_full_blocks: Blocks completely inside the window - - n_back_masked_blocks: Blocks partially overlapping window back - - n_extra_tokens: Padding tokens in last K block + # Right boundary + if IS_CAUSAL: + # Causal cap: col ≤ row + diag + right_min = tl.minimum(seqlen_k - 1, q_start + diag) + right_max = tl.minimum(seqlen_k - 1, q_end + diag) + else: + if WINDOW_SIZE_RIGHT < 0: + right_min = tl.minimum(seqlen_k - 1, q_start + diag + WINDOW_SIZE_RIGHT) + right_max = tl.minimum(seqlen_k - 1, q_end + diag + WINDOW_SIZE_RIGHT) + else: + # Non-causal doesn't have the diagonal constraint + right_min = tl.minimum(seqlen_k - 1, q_start + diag + WINDOW_SIZE_RIGHT) + right_max = tl.minimum(seqlen_k - 1, q_end + diag + WINDOW_SIZE_RIGHT) + + return left_min, left_max, right_min, right_max + +@triton.jit +def classify_window_blocks(left_min, left_max, right_min, right_max, + BLOCK_N: tl.constexpr): + """Classify blocks based on window boundaries.""" + # First and last blocks that have ANY overlap with window + first_block = left_min // BLOCK_N + last_block = right_max // BLOCK_N + + # First block that is FULLY visible for all rows in Q block + full_left_block = left_max // BLOCK_N + (left_max % BLOCK_N != 0) + clipped_left = tl.minimum(full_left_block, last_block + 1) + + # Last block that is FULLY visible for all rows in Q block + last_full_block_candidate = right_min // BLOCK_N + if (last_full_block_candidate + 1) * BLOCK_N - 1 > right_min: + last_full_block_candidate -= 1 + full_right_block = tl.maximum(last_full_block_candidate, clipped_left - 1) + + # Calculate counts + n_front_skip_blocks = first_block + n_front_masked_blocks = tl.maximum(0, clipped_left - first_block) + n_full_blocks = tl.maximum(0, full_right_block - clipped_left + 1) + n_back_masked_blocks = tl.maximum(0, last_block - full_right_block) + + return (n_front_skip_blocks, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks, + clipped_left) # Return clipped_left for padded block handling + +@triton.jit +def handle_padded_last_block(n_extra_tokens, last_block, total_k_blocks, + clipped_left, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks): + """Ensure a padded last K-block is never classified as 'full'. + + We move the padded last block (if visible) into the back-masked bucket. + If it's already back-masked, we do nothing. If it was counted in the + front-masked range, we decrement front-masked; if it was counted as full, + we decrement full. Then we increment back-masked. """ - # Example case - # BLOCK_M = 4, BLOCK_N = 4, seqlen_q = 8, seqlen_k = 10 + padded_last_k = (n_extra_tokens != 0) & (last_block == total_k_blocks - 1) + + if padded_last_k: + # current 'full' range right edge + full_right_block = clipped_left + n_full_blocks - 1 + + # If last_block is already beyond full_right_block, it's already in back-masked → nothing to do + last_already_back_masked = last_block > full_right_block + if not last_already_back_masked: + # If the window starts past last_block, it was counted in front-masked + if clipped_left > last_block: + n_front_masked_blocks = tl.maximum(0, n_front_masked_blocks - 1) + else: + # Otherwise it was counted 'full' → move it out of full + n_full_blocks = tl.maximum(0, n_full_blocks - 1) + # In both cases we need one more back-masked block + n_back_masked_blocks = n_back_masked_blocks + 1 - # Total K blocks in the key sequence - total_k_blocks = tl.cdiv(seqlen_k, BLOCK_N) + return n_front_masked_blocks, n_full_blocks, n_back_masked_blocks +@triton.jit +def compute_padding_info(seqlen_k, BLOCK_N: tl.constexpr): + """Calculate padding information for the last K block.""" # check if we will need to do masking due either BLOCK_N being bigger than seqlen_k or seqlen_k not being a factor of BLOCK_N # n_extra_tokens = 10 % 4 = 2 # This means the last K block has 2 valid tokens and 2 padding positions @@ -415,15 +486,60 @@ def compute_masking(seqlen_k, seqlen_q, start_m, elif seqlen_k % BLOCK_N: n_extra_tokens = seqlen_k % BLOCK_N else: - n_extra_tokens = 0 + n_extra_tokens = 0 + return n_extra_tokens + +@triton.jit +def compute_block_masking(seqlen_k, seqlen_q, start_m, + IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + """ + Classify K blocks for attention computation with sliding window support. + + Returns: + - n_front_skip_blocks: Blocks completely before the window + - n_front_masked_blocks: Blocks partially overlapping window front + - n_full_blocks: Blocks completely inside the window + - n_back_masked_blocks: Blocks partially overlapping window back + - n_extra_tokens: Padding tokens in last K block + """ + + # common + q_start = start_m * BLOCK_M + q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) + diag = seqlen_k - seqlen_q + total_k_blocks = tl.cdiv(seqlen_k, BLOCK_N) + n_extra_tokens = compute_padding_info(seqlen_k, BLOCK_N) if USE_SLIDING_WINDOW: - # TODO: Optimize by computing which blocks can be fully skipped - # For now, process all blocks with the mask function - if IS_CAUSAL: - return 0, 0, 0, total_k_blocks, n_extra_tokens - else: - return 0, 0, 0, total_k_blocks, n_extra_tokens + # get window bounds + left_min, left_max, right_min, right_max = compute_window_bounds( + q_start, q_end, diag, seqlen_k, + WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, IS_CAUSAL + ) + + # window vanishes → early exit + if right_max < left_min: + return 0, 0, 0, 0, n_extra_tokens + + # classify blocks + (n_front_skip_blocks, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks, + clipped_left) = classify_window_blocks( + left_min, left_max, right_min, right_max, BLOCK_N + ) + + # handle padded last block if needed + if n_extra_tokens != 0: + last_block = right_max // BLOCK_N + n_front_masked_blocks, n_full_blocks, n_back_masked_blocks = handle_padded_last_block( + n_extra_tokens, last_block, total_k_blocks, + clipped_left, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks + ) + return (n_front_skip_blocks, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks, n_extra_tokens) else: if IS_CAUSAL: # ========== CAUSAL MODE: Classify K Blocks ========== @@ -444,11 +560,6 @@ def compute_masking(seqlen_k, seqlen_q, start_m, # 1. figure out, in tokens, the right-most K position # this Q-block may attend to # ------------------------------------------------------------ - q_start = start_m * BLOCK_M - q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) - - # causal diagonal offset between the two streams - diag = seqlen_k - seqlen_q # 0 when |Q| == |K| k_max_token = q_end + diag # last visible K index # this Q-block is entirely above the diagonal ⇒ nothing to do @@ -575,7 +686,7 @@ def attn_fwd(Q, K, V, bias, # figure out masking pattern - n_front_skip_blocks, n_front_masked_blocks, n_full_blocks, n_back_masked_blocks, n_extra_tokens = compute_masking( + n_front_skip_blocks, n_front_masked_blocks, n_full_blocks, n_back_masked_blocks, n_extra_tokens = compute_block_masking( seqlen_k, seqlen_q, start_m, IS_CAUSAL, USE_SLIDING_WINDOW, WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, BLOCK_M, BLOCK_N ) From 192b7bc4c83584b2b32b22a237fe3004c9ebfaec Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Thu, 18 Sep 2025 10:26:55 -0400 Subject: [PATCH 12/27] Enable FA V3 (#157) * Compress PA work narrow pa test ref works on most cases inplace ref with new_kv inplace paged attention add pa ref save pa basic paged works save fix swa + causal in pa. Also new_kv only on pa path passing build fa v3 import interface from fa v3 copy fa tests use v3 api clean up rename to match old test support different head sizes remove fp8 basisc passing v3 cases test_flash_attn_varlen_output v3 working isolate bad case for kvcache case passing save use decode is seqused/ cacheseql is given use decode if not varlen basci kvcache v3 working kvcache enable more cases detect kvcache case if seqused_q is non and sequese_k is not None skip failing test find fp8 failing case mha fp8 works fix fp8 MQA/GQA bug clean up more clean up clean up more don't need fp8 dead code remove train code with fp8 stuff fp8 working in kvcache paged + fp8 seems to be working new_kv allowed * clean up * skip hopper race test * clean up more * fix paged + alibi * similar inner paged api * unify _attn_fwd_inner --- .github/workflows/amd_tests.yml | 65 - README.md | 12 +- flash_attn/flash_attn_triton_amd/.gitignore | 2 - flash_attn/flash_attn_triton_amd/Dockerfile | 17 - flash_attn/flash_attn_triton_amd/README.md | 113 -- .../bwd_prefill_fused_atomics.py | 8 +- .../bwd_prefill_fused_no_atomics.py | 299 +++-- .../bwd_prefill_split.py | 10 +- flash_attn/flash_attn_triton_amd/fp8.py | 716 ---------- .../flash_attn_triton_amd/fwd_decode.py | 576 +++++--- .../flash_attn_triton_amd/fwd_prefill.py | 350 +++-- flash_attn/flash_attn_triton_amd/fwd_ref.py | 336 ++++- .../flash_attn_triton_amd/interface_fa.py | 31 +- .../flash_attn_triton_amd/interface_fa_v3.py | 660 +++++++++ flash_attn/flash_attn_triton_amd/test.py | 777 ----------- flash_attn/flash_attn_triton_amd/train.py | 404 ------ flash_attn/flash_attn_triton_amd/utils.py | 91 +- hopper/flash_attn_interface.py | 26 +- hopper/setup.py | 8 +- hopper/test_flash_attn_triton_amd.py | 1174 +++++++++++++++++ tests/test_flash_attn_triton_amd.py | 2 +- 21 files changed, 3021 insertions(+), 2656 deletions(-) delete mode 100644 .github/workflows/amd_tests.yml delete mode 100644 flash_attn/flash_attn_triton_amd/.gitignore delete mode 100644 flash_attn/flash_attn_triton_amd/Dockerfile delete mode 100644 flash_attn/flash_attn_triton_amd/README.md mode change 100644 => 100755 flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py mode change 100644 => 100755 flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py mode change 100644 => 100755 flash_attn/flash_attn_triton_amd/bwd_prefill_split.py delete mode 100644 flash_attn/flash_attn_triton_amd/fp8.py mode change 100644 => 100755 flash_attn/flash_attn_triton_amd/fwd_decode.py mode change 100644 => 100755 flash_attn/flash_attn_triton_amd/fwd_prefill.py mode change 100644 => 100755 flash_attn/flash_attn_triton_amd/fwd_ref.py create mode 100755 flash_attn/flash_attn_triton_amd/interface_fa_v3.py delete mode 100644 flash_attn/flash_attn_triton_amd/test.py delete mode 100644 flash_attn/flash_attn_triton_amd/train.py mode change 100644 => 100755 hopper/flash_attn_interface.py mode change 100644 => 100755 hopper/setup.py create mode 100755 hopper/test_flash_attn_triton_amd.py diff --git a/.github/workflows/amd_tests.yml b/.github/workflows/amd_tests.yml deleted file mode 100644 index 2e3f061c78d..00000000000 --- a/.github/workflows/amd_tests.yml +++ /dev/null @@ -1,65 +0,0 @@ -name: AMD Perf Kernel Tests - -on: - workflow_dispatch: - pull_request: - branches: [main_perf] - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - Integration-Tests-AMD: - runs-on: ${{ matrix.runner }} - strategy: - matrix: - runner: [linux-mi300-gpu-1] - fail-fast: false # disables failing the entire job when one matrix entry fails - timeout-minutes: 720 # self hosted runners can run jobs for longer than the default of 360 minutes - container: - image: rocm/pytorch:latest - options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --shm-size 16G --group-add video --user root - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Show Device Info - run: | - rocminfo | grep gfx - - - name: Uninstall Triton - run: | - pip uninstall -y triton - rm -rf ~/.triton - rm -rf ./triton/python/build - - - name: Install Triton - run: | - pip install triton==3.3.0 - - - name: Show Triton version - run: | - pip show triton - - - name: Build - run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install - - - name: Install dependencies for bench and misc - run: | - pip install matplotlib pandas tabulate - - - name: AMD Internal Tests - run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest flash_attn/flash_attn_triton_amd/test.py - - - name: Flash Attention Tests - run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest -n 8 tests/test_flash_attn_triton_amd.py - - - name: AMD Bench - run: | - python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_func - python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_varlen_func - python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_with_kvcache diff --git a/README.md b/README.md index 65a4154da4a..85c05a9ab68 100755 --- a/README.md +++ b/README.md @@ -183,15 +183,17 @@ WORKDIR /workspace # install triton RUN pip install triton==3.3.0 -# install flash attention -ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" - -RUN git clone https://github.com/ROCm/flash-attention.git &&\ +# build flash attention with triton backend +RUN git clone https://github.com/Dao-AILab/flash-attention &&\ cd flash-attention &&\ - python setup.py install + FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install # set working dir WORKDIR /workspace/flash-attention + +# set env variable to use triton backend +ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + ``` To build the docker file diff --git a/flash_attn/flash_attn_triton_amd/.gitignore b/flash_attn/flash_attn_triton_amd/.gitignore deleted file mode 100644 index 21538fc4e4a..00000000000 --- a/flash_attn/flash_attn_triton_amd/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -bwd_prefill_fused.py -bwd_prefill_onekernel.py \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/Dockerfile b/flash_attn/flash_attn_triton_amd/Dockerfile deleted file mode 100644 index 8df939a0886..00000000000 --- a/flash_attn/flash_attn_triton_amd/Dockerfile +++ /dev/null @@ -1,17 +0,0 @@ -FROM rocm/pytorch:latest - -WORKDIR /workspace - -# install triton -RUN pip install triton==3.3.0 - -# install flash attention -ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" - -RUN git clone https://github.com/ROCm/flash-attention.git &&\ - cd flash-attention &&\ - git checkout main_perf &&\ - python setup.py install - -# set working dir -WORKDIR /workspace/flash-attention \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/README.md b/flash_attn/flash_attn_triton_amd/README.md deleted file mode 100644 index 87213c1883c..00000000000 --- a/flash_attn/flash_attn_triton_amd/README.md +++ /dev/null @@ -1,113 +0,0 @@ -Flash Attention Triton Kernel -=============== - -#### Introduction -The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress. - -It supports AMD's CDNA (MI200, MI300) and RDNA GPU's using fp16, bf16 and fp32 datatypes. - -These features are supported in Fwd and Bwd -1) Fwd and Bwd with causal masking -2) Variable sequence lengths -3) Arbitrary Q and KV sequence lengths -4) Arbitrary head sizes -5) Multi and grouped query attention -6) Dropout -7) Rotary embeddings -8) ALiBi - -We are working on the following things -1) Paged Attention -2) Sliding Window -3) FP8 -4) Performance Improvements - -##### Getting Started -To get started with the triton backend for AMD, follow the steps below. - -First install the recommended Triton version - -``` -pip install triton==3.3.0 -``` -Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. - -``` -cd flash-attention -git checkout main_perf -FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install -``` - -To test that things are working, you can run our tests. These tests take hours so you don't need to run the full thing. -``` -FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py -``` - -You can use autotune for better performance by using this flag `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"` -``` -FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE -``` - -###### Docker -You can also use the Dockerfile below which does the above steps on top of the latest rocm/pytorch image. -``` -FROM rocm/pytorch:latest - -WORKDIR /workspace - -# install triton -RUN pip install triton==3.3.0 - -# install flash attention -ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" - -RUN git clone https://github.com/ROCm/flash-attention.git &&\ - cd flash-attention &&\ - git checkout main_perf &&\ - python setup.py install - -# set working dir -WORKDIR /workspace/flash-attention -``` - -To build the docker file -``` -docker build -t fa_triton . -``` - -To run the docker image -``` -docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton -``` - -###### FP8 -In our fork We have created the following api functions that use fp8 internally to compute their values. These functions are `flash_attn_fp8_func`, `flash_attn_varlen_fp8_func`, `flash_attn_qkvpacked_fp8_func` and `flash_attn_varlen_qkvpacked_fp8_func`. Here is a usage example - -``` -from flash_attn.flash_attn_triton_amd.fp8 import flash_attn_qkvpacked_fp8_func - -# forward pass -out, lse, S_dmask = flash_attn_qkvpacked_fp8_func( - qkv, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - -# backward pass -do = torch.randn_like(out) -dqkv = torch.autograd.grad(out, (qkv), do) -``` - -You can use the other api functions in a similar way. - - - -##### Credits -AMD Triton kernels team - -OpenAI kernel team diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py old mode 100644 new mode 100755 index e969a3770b8..51e53daedc2 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py @@ -1,7 +1,7 @@ import torch import triton import triton.language as tl -from flash_attn.flash_attn_triton_amd.utils import compute_fp8_scaling_factors +from flash_attn.flash_attn_triton_amd.utils import compute_fp8_scaling_factors, DEBUG, is_fp8 from typing import Optional, Tuple @@ -1503,11 +1503,17 @@ def attention_prefill_backward_triton_fused_atomics_impl( descale_v: Optional[torch.Tensor] = None, descale_do: Optional[torch.Tensor] = None, fused: bool = False, + # seqused for FA v3 (currently ignored in this implementation) + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, ): IS_FP8 = is_fp8(q) if IS_FP8: FP8_MAX = torch.finfo(q.dtype).max descale_strides = (descale_q.stride(0),descale_k.stride(0),descale_v.stride(0),descale_do.stride(0) ) + + if DEBUG: + print(f"FP8 path triggered in bwd_prefill_fused_atomics.py") else: FP8_MAX = None stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_do_z = None diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py old mode 100644 new mode 100755 index 5b2f8858d11..0d3b3a6fdf4 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py @@ -3,10 +3,9 @@ import triton # type: ignore import triton.language as tl # type: ignore from typing import Literal, Optional -from .utils import AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, \ +from .utils import AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, DEBUG, compute_fp8_scaling_factors, \ create_dropout_mask, create_dropout_mask_varlen, is_cdna, is_fp8, is_rdna, round_multiple -DEBUG= False # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) @@ -29,7 +28,7 @@ def get_autotune_configs(): ] preprocess_autotune_keys = [ "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", ] causal_autotune_configs = [ triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config @@ -39,7 +38,7 @@ def get_autotune_configs(): ] causal_autotune_keys = [ "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", ] noncausal_autotune_configs = [ triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config @@ -49,7 +48,7 @@ def get_autotune_configs(): ] noncausal_autotune_keys = [ "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", ] return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) @@ -72,21 +71,21 @@ def get_autotune_configs(): ] preprocess_autotune_keys = [ "max_seqlen_q", - "ACTUAL_HEAD_DIM", "IS_VARLEN", + "ACTUAL_HEAD_DIM_V", "IS_VARLEN", ] causal_autotune_configs = [ triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), ] causal_autotune_keys = [ "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", ] noncausal_autotune_configs = [ triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), ] noncausal_autotune_keys = [ "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", ] return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) @@ -117,8 +116,8 @@ def _bwd_preprocess( cu_seqlens_q, max_seqlen_q, Descale_do, PRE_BLOCK: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, IS_VARLEN: tl.constexpr, IS_FP8: tl.constexpr ): @@ -136,7 +135,7 @@ def _bwd_preprocess( # Compute offsets offs_m = pid_m * PRE_BLOCK + tl.arange(0, PRE_BLOCK) - offs_d = tl.arange(0, HEAD_DIM) + offs_d = tl.arange(0, HEAD_DIM_V) # pointer offsets for O & DO off_o = ( bid * stride_ob + hid * stride_oh @@ -152,9 +151,9 @@ def _bwd_preprocess( # create masks mask_m = offs_m < seqlen_q mask_md = mask_m[:, None] - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - if PADDED_HEAD: - mask_md &= offs_d[None, :] < ACTUAL_HEAD_DIM + PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) + if PADDED_HEAD_V: + mask_md &= offs_d[None, :] < ACTUAL_HEAD_DIM_V # load o = tl.load(O + off_o, mask=mask_md, other=0.0) do = tl.load(DO + off_do, mask=mask_md, other=0.0) @@ -185,8 +184,10 @@ def _bwd_dkdv_inner( stride_lse_m, stride_delta_m, BLOCK_M: tl.constexpr, # 16 BLOCK_N: tl.constexpr, # 128 - HEAD_DIM: tl.constexpr, # - ACTUAL_HEAD_DIM: tl.constexpr, # + HEAD_DIM_QK: tl.constexpr, # + HEAD_DIM_V: tl.constexpr, # + ACTUAL_HEAD_DIM_QK: tl.constexpr, # + ACTUAL_HEAD_DIM_V: tl.constexpr, # dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k @@ -203,18 +204,20 @@ def _bwd_dkdv_inner( DEBUG_TRITON_DETAIL: tl.constexpr, ): # if HEAD_DIM is padded - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) + PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) delta_qk = seqlen_q - seqlen_k offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) - offs_k = tl.arange(0, HEAD_DIM) + offs_k_qk = tl.arange(0, HEAD_DIM_QK) + offs_k_v = tl.arange(0, HEAD_DIM_V) # mask to make sure not OOB of seqlen_q mask_n = offs_n < seqlen_k # Q and DO are (seqlen_q, head_dim) - # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q - qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk - # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed - do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM_QK, 1), transpose of q + qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k_qk[:, None] * stride_qk + # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM_V), NOT transposed + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k_v[None, :] * stride_dok # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. tl.static_assert(BLOCK_N % BLOCK_M == 0) curr_m = start_m @@ -231,9 +234,10 @@ def _bwd_dkdv_inner( mask_qT = mask_m[None, :] mask_do = mask_m[:, None] mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) - if PADDED_HEAD: - mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM - mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM + if PADDED_HEAD_QK: + mask_qT &= offs_k_qk[:, None] < ACTUAL_HEAD_DIM_QK + if PADDED_HEAD_V: + mask_do &= offs_k_v[None, :] < ACTUAL_HEAD_DIM_V qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) # generate dropout mask if ENABLE_DROPOUT: @@ -341,8 +345,10 @@ def _bwd_dq_inner( seqlen_q, seqlen_k, # BLOCK_M2: tl.constexpr, # BLOCK_N2: tl.constexpr, # - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, # + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, # dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, # Filled in by the wrapper. @@ -358,17 +364,19 @@ def _bwd_dq_inner( DEBUG_TRITON_DETAIL: tl.constexpr, ): # if HEAD_DIM is padded - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) + PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) delta_qk = seqlen_q - seqlen_k offs_m = start_m + tl.arange(0, BLOCK_M2) offs_n = start_n + tl.arange(0, BLOCK_N2) - offs_k = tl.arange(0, HEAD_DIM) + offs_k_qk = tl.arange(0, HEAD_DIM_QK) + offs_k_v = tl.arange(0, HEAD_DIM_V) # mask to make sure not OOB of seqlen_q mask_m = offs_m < seqlen_q - kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk - vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k_qk[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k_v[:, None] * stride_vk # D (= delta) is pre-divided by ds_scale. Di = tl.load(Delta + offs_m * stride_delta_m, mask=mask_m, other=0.0) # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. @@ -387,12 +395,15 @@ def _bwd_dq_inner( if DEBUG_TRITON_DETAIL: print(f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}") # noqa: E701 if DEBUG_TRITON_DETAIL: print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 mask_kT = mask_n[None, :] + mask_vT = mask_n[None, :] mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) - if PADDED_HEAD: - mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM + if PADDED_HEAD_QK: + mask_kT &= offs_k_qk[:, None] < ACTUAL_HEAD_DIM_QK + if PADDED_HEAD_V: + mask_vT &= offs_k_v[:, None] < ACTUAL_HEAD_DIM_V kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) - vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_vT, other=0.0) if ENABLE_DROPOUT: # NOTE: dropout is transposed because it is used to mask pT @@ -477,6 +488,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_az, stride_ah, HQ, HK, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Add seqused parameters max_seqlen_q, max_seqlen_k, Dropout_mask, dropout_p, philox_seed, philox_offset_base, Alibi_slopes, @@ -486,8 +498,10 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr, USE_ALIBI: tl.constexpr, @@ -495,6 +509,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr, + USE_SEQUSED: tl.constexpr, # Add flag for seqused DEBUG_TRITON: tl.constexpr, DEBUG_TRITON_DETAIL: tl.constexpr, ): @@ -514,21 +529,31 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba q_end = tl.load(cu_seqlens_q + bid + 1) k_start = tl.load(cu_seqlens_k + bid) k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = tl.load(seqused_q + bid) if seqused_q is not None else q_end - q_start + seqlen_q = tl.minimum(actual_seqlen_q, q_end - q_start) + actual_seqlen_k = tl.load(seqused_k + bid) if seqused_k is not None else k_end - k_start + seqlen_k = tl.minimum(actual_seqlen_k, k_end - k_start) + else: + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start delta_qk = seqlen_q - seqlen_k if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") # noqa: E701 - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - offs_d = tl.arange(0, HEAD_DIM) + PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) + PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) + offs_d_qk = tl.arange(0, HEAD_DIM_QK) + offs_d_v = tl.arange(0, HEAD_DIM_V) GROUP_SIZE: tl.constexpr = HQ // HK # align the delta_qk start_n = pid * BLOCK_N1 if start_n < seqlen_k: # This section does dk and dv - dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM_QK], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM_V], dtype=tl.float32) # q > k: diretcly skip all the way until the start of causal block start_delta_q_gt_k = delta_qk @@ -548,17 +573,21 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba offs_n = start_n + tl.arange(0, BLOCK_N1) # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - if PADDED_HEAD: - mask_d = offs_d < ACTUAL_HEAD_DIM - mask_kv &= mask_d[None, :] + mask_k = offs_n[:, None] < seqlen_k + mask_v = offs_n[:, None] < seqlen_k + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_k &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_v &= mask_d_v[None, :] # K/V tensors not changed for the group - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d_qk[None, :] * stride_kd + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d_v[None, :] * stride_vd # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k, mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + k = tl.load(K + adj_k, mask=mask_k, other=0.0) + v = tl.load(V + adj_v, mask=mask_v, other=0.0) # If MQA / GQA, set the K and V head offsets appropriately. # hqid = hkid for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): @@ -628,7 +657,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_dropoutm, stride_dropoutn, # strides for dropout stride_lse_m, stride_delta_m, MASK_BLOCK_M1, BLOCK_N1, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, # head dim dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k @@ -658,7 +687,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_dropoutm, stride_dropoutn, # strides for dropout stride_lse_m, stride_delta_m, BLOCK_M1, BLOCK_N1, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, # head dim dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k @@ -676,13 +705,13 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba # end of GQA/MQA of dkdv # Write back dV adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn - offs_dv = offs_n[:, None] * stride_dvn + offs_d[None, :] * stride_dvd - tl.store(DV + adj_dv + offs_dv, dv, mask=mask_kv) + offs_dv = offs_n[:, None] * stride_dvn + offs_d_v[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_v) # write back dk adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn - offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd + offs_dk = offs_n[:, None] * stride_dkn + offs_d_qk[None, :] * stride_dkd dk *= sm_scale - tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_k) # This part does dq start_m = pid * BLOCK_M2 @@ -696,11 +725,15 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba offs_m = start_m + tl.arange(0, BLOCK_M2) # Mask for loading K and V mask_q = offs_m[:, None] < seqlen_q - if PADDED_HEAD: - mask_d = offs_d < ACTUAL_HEAD_DIM - mask_q &= mask_d[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd - offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + mask_do = offs_m[:, None] < seqlen_q + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_q &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_do &= mask_d_v[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d_v[None, :] * stride_dod # NOTE: don't assume that the strides for k and v are the same! K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn @@ -739,7 +772,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba dropout_offset = \ Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_do, other=0.0) m = tl.load(M + adj_m + offs_m * stride_lse_m, mask=offs_m < seqlen_q) m = m[:, None] @@ -757,7 +790,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba else: descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) dq = _bwd_dq_inner( dq, q, K, V, do, m, Delta_ptr, sm_scale, @@ -767,7 +800,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_delta_m, seqlen_q, seqlen_k, BLOCK_M2, MASK_BLOCK_N2, - HEAD_DIM, ACTUAL_HEAD_DIM, + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, start_m, start_n, end_n, num_steps, @@ -794,7 +827,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_delta_m, seqlen_q, seqlen_k, BLOCK_M2, BLOCK_N2, - HEAD_DIM, ACTUAL_HEAD_DIM, + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, start_m, start_n, end_n, num_steps, @@ -810,7 +843,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba ) # Write back dQ. adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + offs_dq = offs_m[:, None] * stride_dqm + offs_d_qk[None, :] * stride_dqd dq *= sm_scale tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) # end of GQA/MQA of dq @@ -838,6 +871,7 @@ def bwd_kernel_noncausal( stride_az, stride_ah, HQ, HK, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Add seqused parameters max_seqlen_q, max_seqlen_k, Dropout_mask, dropout_p, philox_seed, philox_offset_base, Alibi_slopes, @@ -847,8 +881,10 @@ def bwd_kernel_noncausal( BLOCK_M2: tl.constexpr, # 128 BLOCK_N2: tl.constexpr, # 32 BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr, USE_ALIBI: tl.constexpr, @@ -856,6 +892,7 @@ def bwd_kernel_noncausal( IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr, + USE_SEQUSED: tl.constexpr, # Add flag for seqused DEBUG_TRITON: tl.constexpr, DEBUG_TRITON_DETAIL: tl.constexpr, ): @@ -875,31 +912,45 @@ def bwd_kernel_noncausal( q_end = tl.load(cu_seqlens_q + bid + 1) k_start = tl.load(cu_seqlens_k + bid) k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = tl.load(seqused_q + bid) if seqused_q is not None else q_end - q_start + seqlen_q = tl.minimum(actual_seqlen_q, q_end - q_start) + actual_seqlen_k = tl.load(seqused_k + bid) if seqused_k is not None else k_end - k_start + seqlen_k = tl.minimum(actual_seqlen_k, k_end - k_start) + else: + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - offs_d = tl.arange(0, HEAD_DIM) + PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) + PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) + offs_d_qk = tl.arange(0, HEAD_DIM_QK) + offs_d_v = tl.arange(0, HEAD_DIM_V) GROUP_SIZE: tl.constexpr = HQ // HK start_n = pid * BLOCK_N1 if start_n < seqlen_k: - dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM_QK], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM_V], dtype=tl.float32) offs_n = start_n + tl.arange(0, BLOCK_N1) # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - if PADDED_HEAD: - mask_d = offs_d < ACTUAL_HEAD_DIM - mask_kv &= mask_d[None, :] + mask_k = offs_n[:, None] < seqlen_k + mask_v = offs_n[:, None] < seqlen_k + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_k &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_v &= mask_d_v[None, :] # NOTE: don't assume that the strides for k and v are the same! # K/V tensors not changed for the group - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d_qk[None, :] * stride_kd + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d_v[None, :] * stride_vd # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k, mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + k = tl.load(K + adj_k, mask=mask_k, other=0.0) + v = tl.load(V + adj_v, mask=mask_v, other=0.0) # If MQA / GQA, set the K and V head offsets appropriately. for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): # offset input and output tensor by batch and Q/K heads @@ -948,7 +999,7 @@ def bwd_kernel_noncausal( stride_lse_m, stride_delta_m, BLOCK_M1, BLOCK_N1, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, # head dim dropout_p, philox_seed, batch_philox_offset, dropout_offset, # alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k @@ -966,13 +1017,13 @@ def bwd_kernel_noncausal( # Write back dV adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn - offs_dv = offs_n[:, None] * stride_dvn + offs_d[None, :] * stride_dvd - tl.store(DV + adj_dv + offs_dv, dv, mask=mask_kv) + offs_dv = offs_n[:, None] * stride_dvn + offs_d_v[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_v) # write back dk adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn - offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd + offs_dk = offs_n[:, None] * stride_dkn + offs_d_qk[None, :] * stride_dkd dk *= sm_scale - tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_k) # THIS PART DOES DQ start_m = pid * BLOCK_M2 @@ -980,11 +1031,15 @@ def bwd_kernel_noncausal( offs_m = start_m + tl.arange(0, BLOCK_M2) # Mask for loading K and V mask_q = offs_m[:, None] < seqlen_q - if PADDED_HEAD: - mask_d = offs_d < ACTUAL_HEAD_DIM - mask_q &= mask_d[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd - offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + mask_do = offs_m[:, None] < seqlen_q + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_q &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_do &= mask_d_v[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d_v[None, :] * stride_dod K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn # If MQA / GQA, set the K and V head offsets appropriately. @@ -1016,7 +1071,7 @@ def bwd_kernel_noncausal( Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_do, other=0.0) m = tl.load(M + adj_m + offs_m * stride_lse_m, mask=offs_m < seqlen_q) m = m[:, None] @@ -1034,7 +1089,7 @@ def bwd_kernel_noncausal( end_n = seqlen_k num_steps = tl.cdiv(seqlen_k, BLOCK_N2) - dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) dq = _bwd_dq_inner( dq, q, K, V, do, m, Delta_ptr, sm_scale, @@ -1044,7 +1099,7 @@ def bwd_kernel_noncausal( stride_delta_m, seqlen_q, seqlen_k, BLOCK_M2, BLOCK_N2, - HEAD_DIM, ACTUAL_HEAD_DIM, + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, start_m, start_n, end_n, num_steps, @@ -1060,7 +1115,7 @@ def bwd_kernel_noncausal( ) # Write back dQ. adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + offs_dq = offs_m[:, None] * stride_dqm + offs_d_qk[None, :] * stride_dqd dq *= sm_scale tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) @@ -1106,6 +1161,9 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( descale_dq: Optional[torch.Tensor], descale_dk: Optional[torch.Tensor], descale_dv: Optional[torch.Tensor], + # seqused for FA v3 + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, ): # get params, strides and shape IS_VARLEN = layout == "thd" @@ -1135,13 +1193,13 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( assert max_seqlen_k is not None, "max_seqlen_k must be provided for varlen layout" # assert head dimensions - assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" assert nheads_lse == nheads_q, f"softmax_lse heads {nheads_lse} != q heads {nheads_q}" # assert output shapes - assert o.shape == (total_seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_q)}" + assert o.shape == (total_seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_v)}" assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}" @@ -1157,7 +1215,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( # set vars batch = len(cu_seqlens_q) - 1 - head_size = head_size_q + head_size_qk = head_size_q # strides stride_qb, stride_qm, stride_qh, stride_qd = 0, q.stride(0), q.stride(1), q.stride(2) @@ -1180,7 +1238,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( assert batch_q == batch_k == batch_v, f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" # assert head dimensions - assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" @@ -1188,7 +1246,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( assert seqlen_k == seqlen_v, f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" # assert output shapes - assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected" + assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected" assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}" @@ -1199,7 +1257,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( # set vars batch = batch_q - head_size = head_size_q + head_size_qk = head_size_q max_seqlen_q = seqlen_q max_seqlen_k = seqlen_k @@ -1233,6 +1291,9 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None + + if DEBUG: + print(f"FP8 path triggered in bwd_prefill_fused_no_atomics.py (FP8_OUTPUT={FP8_OUTPUT})") else: FP8_MAX = None FP8_OUTPUT = False @@ -1242,10 +1303,14 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) # get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() - padded_d_model = max(padded_d_model, 32) - HEAD_DIM = padded_d_model - ACTUAL_HEAD_DIM = head_size + padded_d_model_qk = 1 << (head_size_qk - 1).bit_length() + padded_d_model_qk = max(padded_d_model_qk, 32) + padded_d_model_v = 1 << (head_size_v - 1).bit_length() + padded_d_model_v = max(padded_d_model_v, 32) + HEAD_DIM_QK = padded_d_model_qk + HEAD_DIM_V = padded_d_model_v + ACTUAL_HEAD_DIM_QK = head_size_qk + ACTUAL_HEAD_DIM_V = head_size_v # init delta if OLD_LSE: @@ -1280,13 +1345,13 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( stride_descale_do_z, cu_seqlens_q, max_seqlen_q, descale_do, - HEAD_DIM=HEAD_DIM, - ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, IS_VARLEN=IS_VARLEN, IS_FP8=IS_FP8 ) - if DEBUG: + if False: print("delta:", delta, delta.shape) # dropout mask tensor for debugging. We dump the dropout mask created in @@ -1337,12 +1402,15 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( stride_az, stride_ah, nheads_q, nheads_k, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Pass seqused tensors max_seqlen_q, max_seqlen_k, dropout_mask, dropout_p, philox_seed, philox_offset, alibi_slopes, descale_q, descale_k, descale_v, descale_do, - HEAD_DIM=HEAD_DIM, - ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, USE_ALIBI=use_alibi, @@ -1350,6 +1418,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT, + USE_SEQUSED=(seqused_q is not None or seqused_k is not None), # Add flag for seqused DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) @@ -1371,12 +1440,15 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( stride_az, stride_ah, nheads_q, nheads_k, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Pass seqused tensors max_seqlen_q, max_seqlen_k, dropout_mask, dropout_p, philox_seed, philox_offset, alibi_slopes, descale_q, descale_k, descale_v, descale_do, - HEAD_DIM=HEAD_DIM, - ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, USE_ALIBI=use_alibi, @@ -1384,6 +1456,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT, + USE_SEQUSED=(seqused_q is not None or seqused_k is not None), # Add flag for seqused DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py old mode 100644 new mode 100755 index 2728dca7349..cfff39bf8bf --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py @@ -2,11 +2,9 @@ import triton # type: ignore import triton.language as tl # type: ignore from typing import Literal, Optional -from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, get_shapes_from_layout, \ +from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, DEBUG, compute_fp8_scaling_factors, get_shapes_from_layout, \ get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_fp8 -DEBUG = False - # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) @@ -1092,6 +1090,9 @@ def attention_prefill_backward_triton_split_impl( descale_dq: Optional[torch.Tensor], descale_dk: Optional[torch.Tensor], descale_dv: Optional[torch.Tensor], + # seqused for FA v3 (currently ignored in this implementation) + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, ): # debug DEBUG_TRITON: bool = False @@ -1117,6 +1118,9 @@ def attention_prefill_backward_triton_split_impl( stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None + + if DEBUG: + print(f"FP8 path triggered in bwd_prefill_split.py (FP8_OUTPUT={FP8_OUTPUT})") else: FP8_MAX = None FP8_OUTPUT = False diff --git a/flash_attn/flash_attn_triton_amd/fp8.py b/flash_attn/flash_attn_triton_amd/fp8.py deleted file mode 100644 index df79c7926b2..00000000000 --- a/flash_attn/flash_attn_triton_amd/fp8.py +++ /dev/null @@ -1,716 +0,0 @@ -from typing import Optional, Sequence, Tuple, Union -import torch -import torch.nn as nn -from .utils import cast_to_fp8, is_fp8 -from . import interface_fa as flash_attn_gpu - - -def maybe_contiguous(x): - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - -class FlashAttnFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "bshd") - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "bshd") - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "bshd") - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - alibi_slopes, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] # NOTE: this used to be out_padded. It might cause issue doing an empty - - # check output type - assert out.dtype == q.dtype, "Input and output type must match otherwise there will be implicit casting by autograd" - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dq, descale_dq = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - dk, descale_dk = torch.zeros_like(k_fp8), torch.zeros_like(descale_k) - dv, descale_dv = torch.zeros_like(v_fp8), torch.zeros_like(descale_v) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "bshd") - dq, descale_dq = torch.zeros_like(q_fp8, dtype=torch.float32), None - dk, descale_dk = torch.zeros_like(k_fp8, dtype=torch.float32), None - dv, descale_dv = torch.zeros_like(v_fp8, dtype=torch.float32), None - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dq, - dk, - dv, - ctx.alibi_slopes, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, # gen_ - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None - -def flash_attn_fp8_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None -): - return FlashAttnFP8Func.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - descale_q, - descale_k, - descale_v, - descale_do - ) - -class FlashAttnVarlenFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - block_table, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_q, max_seqlen=max_seqlen_q) - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_k, max_seqlen=max_seqlen_k) - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_k, max_seqlen=max_seqlen_k) - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - cu_seqlens_q, - cu_seqlens_k, - None, - None, - block_table, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - False, - causal, - window_size[0], - window_size[1], - softcap, - return_softmax, - None, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] # NOTE: this used to be out_padded. It might cause issue doing an empty - - # check output type - assert out.dtype == q.dtype, "Input and output type must match otherwise there will be implicit casting by autograd" - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout_padded): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dq, descale_dq = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - dk, descale_dk = torch.zeros_like(k_fp8), torch.zeros_like(descale_k) - dv, descale_dv = torch.zeros_like(v_fp8), torch.zeros_like(descale_v) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_q, max_seqlen=ctx.max_seqlen_q) - dq, descale_dq = torch.zeros_like(q_fp8, dtype=torch.float32), None - dk, descale_dk = torch.zeros_like(k_fp8, dtype=torch.float32), None - dv, descale_dv = torch.zeros_like(v_fp8, dtype=torch.float32), None - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.varlen_bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - ctx.alibi_slopes, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.softmax_scale, - False, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_varlen_fp8_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - block_table=None -): - return FlashAttnVarlenFP8Func.apply( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - block_table, - torch.is_grad_enabled() - ) - -class FlashAttnQKVPackedFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and qkv.requires_grad - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach() - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "bshd") - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "bshd") - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "bshd") - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - alibi_slopes, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o, - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - qkv_shape = q_fp8.shape[:-2] + (3, *q_fp8.shape[-2:]) - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dqkv, descale_dqkv = torch.zeros(qkv_shape, device=q_fp8.device), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "bshd") - dqkv, descale_dqkv = torch.zeros(qkv_shape, dtype=torch.float32, device=q_fp8.device), None - - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dqkv[:, :, 0], - dqkv[:, :, 1], - dqkv[:, :, 2], - ctx.alibi_slopes, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, # gen_ - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - None, - None, - None, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None - - -def flash_attn_qkvpacked_fp8_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # <=0.0 means deactivate - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - return FlashAttnQKVPackedFP8Func.apply( - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -class FlashAttnVarlenQKVPackedFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and qkv.requires_grad - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach() - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - cu_seqlens, - cu_seqlens, - None, - None, - None, - alibi_slopes, - max_seqlen, - max_seqlen, - dropout_p, - softmax_scale, - False, - causal, - window_size[0], - window_size[1], - softcap, - return_softmax, - None, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.max_seqlen = max_seqlen - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - qkv_shape = q_fp8.shape[:-2] + (3, *q_fp8.shape[-2:]) - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout_padded): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dqkv, descale_dqkv = torch.zeros(qkv_shape, device=q_fp8.device), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=ctx.max_seqlen) - dqkv, descale_dqkv = torch.zeros(qkv_shape, dtype=torch.float32, device=q_fp8.device), None - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.varlen_bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dqkv[:, 0], - dqkv[:, 1], - dqkv[:, 2], - cu_seqlens, - cu_seqlens, - ctx.alibi_slopes, - ctx.max_seqlen, - ctx.max_seqlen, - ctx.dropout_p, - ctx.softmax_scale, - False, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - None, - None, - None, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_varlen_qkvpacked_fp8_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - return FlashAttnVarlenQKVPackedFP8Func.apply( - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py old mode 100644 new mode 100755 index d3f7f9c32b9..327967cebf7 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -2,7 +2,7 @@ import triton import triton.language as tl from typing import Literal, Optional, Union -from .utils import AUTOTUNE, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna +from .utils import AUTOTUNE, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna, is_fp8 DEBUG = False @@ -59,6 +59,118 @@ def get_autotune_configs(): (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) = get_autotune_configs() +@triton.jit +def _attn_fwd_inner( + q, kT, v, pos, col_mask, + m_i, l_i, acc, + pid_m, + q_descale, k_descale, v_descale, # FP8 scaling factors + IS_FP8: tl.constexpr, # FP8 flag + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + N_CTX_Q: tl.constexpr, + N_CTX_K_FINAL: tl.constexpr, + USE_ALIBI: tl.constexpr, + alibi_slope, + USE_SLIDING_WINDOW: tl.constexpr, + IS_CAUSAL: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + APPLY_COL_MASK: tl.constexpr, # apply provided col_mask when True +): + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_FP8: + qk += (tl.dot(q, kT) * q_descale * k_descale) # Apply FP8 scaling + else: + qk += tl.dot(q, kT) # noqa: F821 + + if USE_ALIBI: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = pos + tl.arange(0, BLOCK_N) + + # Compute relative positions + relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) + relative_pos = tl.abs(relative_pos) + + # Compute ALiBi bias + alibi_bias = -1 * alibi_slope * relative_pos + qk += (alibi_bias * 1.44269504) + + # ------------------------------------------------------------------ + # masking + # ------------------------------------------------------------------ + if USE_SLIDING_WINDOW: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # q positions + col_idx = pos + tl.arange(0, BLOCK_N) # k positions + row = row_idx[:, None] # [M,1] + col = col_idx[None, :] # [1,N] + + if IS_CAUSAL: + # -------- causal + window -------- + diag = N_CTX_K_FINAL - N_CTX_Q # sk-sq + causal_ok = col <= row + diag + if WINDOW_SIZE_LEFT < 0: # only right window + win_ok = col <= row + diag + WINDOW_SIZE_RIGHT + else: # both sides + win_ok = ((col >= row + diag - WINDOW_SIZE_LEFT) & + (col <= row + diag + WINDOW_SIZE_RIGHT)) + mask = ~(causal_ok & win_ok) # True ⇒ -inf + else: + # -------- non-causal window -------- + sk, sq = N_CTX_K_FINAL, N_CTX_Q + if WINDOW_SIZE_LEFT < 0: + mask = col > row + (sk - sq) + WINDOW_SIZE_RIGHT + else: + right = tl.minimum(row + (sk - sq) + WINDOW_SIZE_RIGHT, sk) + left = row + (sk - sq) - WINDOW_SIZE_LEFT + mask = (col > right) | (col < left) + qk = tl.where(mask, float("-inf"), qk) + else: + if IS_CAUSAL: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = pos + tl.arange(0, BLOCK_N) + + # create a N_CTX_Q x kv_len causal mask + col_offset = N_CTX_K_FINAL - N_CTX_Q + causal_mask = row_idx[:, None] >= (col_idx[None, :] - col_offset) + + # Apply the mask + qk = tl.where(causal_mask, qk, float("-inf")) + + # Column mask (tail / variable-length). Instead of recomputing an arange each time, + # we accept a precomputed mask from the caller (col_valid_mask). + if APPLY_COL_MASK: + # Expect col_mask shape: [BLOCK_N]. True where column is within sequence. + qk = tl.where(col_mask[None, :], qk, float("-inf")) + + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) # per-row max so far + + # rows that are *all* -inf after masking + valid = m_i_new > float("-inf") + + # scale previous partial sums safely + alpha = tl.where(valid, tl.math.exp2(m_i - m_i_new), 0.0) + + # subtract the row max only on valid rows + qk = tl.where(valid[:, None], qk - m_i_new[:, None], float("-inf")) + p = tl.math.exp2(qk) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(q.dtype) + + # -- scale and update acc -- + acc *= alpha[:, None] + if IS_FP8: + acc += tl.dot(p.to(v.dtype), v) * v_descale # Apply FP8 scaling for V + else: + acc += tl.dot(p.to(v.dtype), v) + + return m_i, l_i, acc + + # @triton.autotune( # configs=fwd_auto_tune_configs, # key=fwd_autotune_keys, @@ -69,6 +181,9 @@ def _fwd_kernel_splitK( Q, K, V, + Q_Descale, # FP8 descale factors for Q + K_Descale, # FP8 descale factors for K + V_Descale, # FP8 descale factors for V sm_scale, Out_splitK, # [B*H*G, split_k, Mq, K] Metadata, # [B*H*G, 2, split_k, M_ceil] contains [mi, li] @@ -76,6 +191,7 @@ def _fwd_kernel_splitK( V_new, Cache_seqlens, Cache_batch_idx, + Block_table, Alibi_slopes, stride_qz, stride_qm, @@ -110,13 +226,22 @@ def _fwd_kernel_splitK( stride_vn_g, stride_vn_h, stride_vn_d, + stride_bt_b, + stride_bt_s, stride_az, stride_ah, + stride_q_descale_z, # FP8 descale strides + stride_q_descale_h, + stride_k_descale_z, + stride_k_descale_h, + stride_v_descale_z, + stride_v_descale_h, Z, N_CTX_Q, N_CTX_K, N_CTX_NEW, BLOCK_N_PER_SPLIT, + BLOCK_SIZE_K: tl.constexpr, H_q: tl.constexpr, H_kv: tl.constexpr, G_q: tl.constexpr, @@ -136,6 +261,8 @@ def _fwd_kernel_splitK( USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, + USE_BLOCK_TABLE: tl.constexpr, + IS_FP8: tl.constexpr, # FP8 flag ): # get program ids pid_m = tl.program_id(0) @@ -155,14 +282,24 @@ def _fwd_kernel_splitK( hk_id = hq_id hv_id = hq_id + # Load FP8 descale factors if needed + if IS_FP8: + if IS_GQA: + # For MQA/GQA, q_descale uses the same indexing as k/v (hk_id) + q_descale = tl.load(Q_Descale + z_id * stride_q_descale_z + hk_id * stride_q_descale_h) + else: + # For MHA, q_descale uses hq_id + q_descale = tl.load(Q_Descale + z_id * stride_q_descale_z + hq_id * stride_q_descale_h) + k_descale = tl.load(K_Descale + z_id * stride_k_descale_z + hk_id * stride_k_descale_h) + v_descale = tl.load(V_Descale + z_id * stride_v_descale_z + hv_id * stride_v_descale_h) + else: + q_descale, k_descale, v_descale = 1.0, 1.0, 1.0 + # figure out seqlens lo = pid_splitk * BLOCK_N_PER_SPLIT if USE_CACHE_SEQLENs: cache_seqlen_last_idx = tl.load(Cache_seqlens + z_id) - if NEW_KV: - N_CTX_K_FINAL = cache_seqlen_last_idx + N_CTX_NEW - else: - N_CTX_K_FINAL = cache_seqlen_last_idx + N_CTX_K_FINAL = cache_seqlen_last_idx else: N_CTX_K_FINAL = N_CTX_K hi = tl.minimum((pid_splitk + 1) * BLOCK_N_PER_SPLIT, N_CTX_K_FINAL) @@ -181,8 +318,15 @@ def _fwd_kernel_splitK( # compute ptrs q_offset = Q + hq_id * stride_qh + z_id * stride_qz + g_id * stride_qg q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd - k_offset = K + hk_id * stride_kh + cache_batch_idx * stride_kz + g_id * stride_kg - v_offset = V + hv_id * stride_vh + cache_batch_idx * stride_vz + g_id * stride_vg + + # Handle block table for paged attention + if USE_BLOCK_TABLE: + # K and V now point to paged cache + # Each batch has its own block table row + block_table_ptr = Block_table + z_id * stride_bt_b + else: + k_offset = K + hk_id * stride_kh + cache_batch_idx * stride_kz + g_id * stride_kg + v_offset = V + hv_id * stride_vh + cache_batch_idx * stride_vz + g_id * stride_vg # compute masks if PADDED_HEAD: @@ -212,160 +356,122 @@ def _fwd_kernel_splitK( else: alibi_slope = None - # Copy new Keys and Values into Cache - if NEW_KV: - knew_base = K_new + hk_id * stride_kn_h + z_id * stride_kn_z + g_id * stride_kn_g - - # Determine the starting position for new data in the cache - if USE_CACHE_SEQLENs: - start_idx = tl.load(Cache_seqlens + z_id) - else: - start_idx = N_CTX_K - N_CTX_NEW - - # Copy new Keys - for i in range(0, N_CTX_NEW, BLOCK_N): - # Load from K_new - k_new_block = tl.load( - knew_base + - tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d + - (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n, - mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), - other=0 - ) - - # Store to K - tl.store( - k_offset + - tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd + - (tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn, - k_new_block, - mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), - ) - - # Copy new Values - vnew_base = V_new + hv_id * stride_vn_h + z_id * stride_vn_z + g_id * stride_vn_g - for i in range(0, N_CTX_NEW, BLOCK_N): - # Load from V_new - v_new_block = tl.load( - vnew_base + - (tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n + - tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d, - mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), - other=0 - ) - - # Store to V - tl.store( - v_offset + - (tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn + - tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd, - v_new_block, - mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), - ) - - # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821 - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - kT_ptrs = k_offset + offs_d[:, None] * stride_kd + (start_n + offs_n)[None, :] * stride_kn - V_ptrs = v_offset + (start_n + offs_n)[:, None] * stride_vn + offs_d[None, :] * stride_vd - - # load k - kT = tl.load(kT_ptrs, mask=kT_mask, other=0.0) - v = tl.load(V_ptrs, mask=v_mask, other=0.0) - - # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, kT) # noqa: F821 - - if USE_ALIBI: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = start_n + tl.arange(0, BLOCK_N) - - # Compute relative positions - relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) - relative_pos = tl.abs(relative_pos) + if USE_BLOCK_TABLE: + # Paged attention: process all KV blocks from cache + # Note: Cache should be updated externally before calling this kernel + num_kv_blocks = (N_CTX_K_FINAL + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K + + for block_idx in range(num_kv_blocks): + # Calculate sequence range for this block + block_start = block_idx * BLOCK_SIZE_K + block_end = tl.minimum(block_start + BLOCK_SIZE_K, N_CTX_K_FINAL) - # Compute ALiBi bias - alibi_bias = -1 * alibi_slope * relative_pos - qk += (alibi_bias * 1.44269504) - - # ------------------------------------------------------------------ - # masking - # ------------------------------------------------------------------ - if USE_SLIDING_WINDOW: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # q positions - col_idx = start_n + tl.arange(0, BLOCK_N) # k positions - row = row_idx[:, None] # [M,1] - col = col_idx[None, :] # [1,N] - - if IS_CAUSAL: - # -------- causal + window -------- - diag = N_CTX_K_FINAL - N_CTX_Q # sk-sq - causal_ok = col <= row + diag - if WINDOW_SIZE_LEFT < 0: # only right window - win_ok = col <= row + diag + WINDOW_SIZE_RIGHT - else: # both sides - win_ok = ((col >= row + diag - WINDOW_SIZE_LEFT) & - (col <= row + diag + WINDOW_SIZE_RIGHT)) - mask = ~(causal_ok & win_ok) # True ⇒ -inf - else: - # -------- non-causal window -------- - sk, sq = N_CTX_K_FINAL, N_CTX_Q - if WINDOW_SIZE_LEFT < 0: - mask = col > row + (sk - sq) + WINDOW_SIZE_RIGHT + # Check if block overlaps with our split-k range [lo, hi) + if block_end > lo and block_start < hi: + # Load physical block number + physical_block = tl.load(block_table_ptr + block_idx * stride_bt_s) + + # Calculate the range within this block that overlaps with [lo, hi) + process_start = tl.maximum(lo - block_start, 0) + process_end = tl.minimum(hi - block_start, BLOCK_SIZE_K) + process_end = tl.minimum(process_end, block_end - block_start) + + # Instead of forcing a floor alignment to BLOCK_N (which can still skip + # part of the intended range if start falls mid-tile for small splits), + # start from the raw (possibly unaligned) process_start rounded *down* but + # allow the loop to begin earlier (at most BLOCK_N before) so that any + # partial tile overlapping [lo, hi) is covered. Masking below will remove + # columns < lo or >= hi ensuring numerically identical coverage without + # duplication. + aligned_start = (process_start // BLOCK_N) * BLOCK_N + if aligned_start > 0 and aligned_start + BLOCK_N > process_start: + # ensure we include the tile that contains process_start + process_start = aligned_start else: - right = tl.minimum(row + (sk - sq) + WINDOW_SIZE_RIGHT, sk) - left = row + (sk - sq) - WINDOW_SIZE_LEFT - mask = (col > right) | (col < left) - qk = tl.where(mask, float("-inf"), qk) - else: - if IS_CAUSAL: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = start_n + tl.arange(0, BLOCK_N) - - # create a N_CTX_Q x kv_len causal mask - col_offset = N_CTX_Q - N_CTX_K_FINAL - causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :]) - - # Apply the mask - qk = tl.where(causal_mask, qk, float("-inf")) - - # TODO: This is slow, and only needed at the last iteration. - # Maybe we can unroll the last iteration instead? - if BOUNDS_CHECKS_N: - qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) - - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) # per-row max so far - - # rows that are *all* -inf after masking - valid = m_i_new > float("-inf") - - # scale previous partial sums safely - alpha = tl.where(valid, tl.math.exp2(m_i - m_i_new), 0.0) - - # subtract the row max only on valid rows - qk = tl.where(valid[:, None], qk - m_i_new[:, None], float("-inf")) - p = tl.math.exp2(qk) - - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - p = p.to(Q.dtype.element_ty) - - # -- scale and update acc -- - acc *= alpha[:, None] - acc += tl.dot(p.to(v.dtype), v) + process_start = aligned_start + + for offset in range(process_start, process_end, BLOCK_N): + # Current position (may begin slightly before logical split range; masking fixes it) + pos = block_start + offset + # Proceed unconditionally; masking below enforces [lo, hi) + # Calculate base addresses for K and V in this physical block + k_base = K + physical_block * BLOCK_SIZE_K * stride_kn + hk_id * stride_kh + g_id * stride_kg + v_base = V + physical_block * BLOCK_SIZE_K * stride_vn + hv_id * stride_vh + g_id * stride_vg + + # Offsets within the current block + block_offs = offset + offs_n + + # Masks for valid data respecting: + # (1) global key length (seq_mask) + # (2) block bounds (block_mask) + # (3) current split range [lo, hi) + seq_mask = ((pos + offs_n) < N_CTX_K_FINAL) + block_mask = (block_offs < BLOCK_SIZE_K) + end_mask = (block_offs < process_end) + split_mask = ((pos + offs_n) >= lo) & ((pos + offs_n) < hi) + col_mask = seq_mask & block_mask & end_mask & split_mask + + # Apply masks + kT_mask_final = kT_mask & col_mask[None, :] + v_mask_final = v_mask & col_mask[:, None] + + # Load K and V + kT_ptrs = k_base + offs_d[:, None] * stride_kd + block_offs[None, :] * stride_kn + v_ptrs = v_base + block_offs[:, None] * stride_vn + offs_d[None, :] * stride_vd + + kT = tl.load(kT_ptrs, mask=kT_mask_final, other=0.0) + v = tl.load(v_ptrs, mask=v_mask_final, other=0.0) + + # Unified inner function handles both paged and contiguous + m_i, l_i, acc = _attn_fwd_inner( + q, kT, v, pos, col_mask, + m_i, l_i, acc, + pid_m, + q_descale, k_descale, v_descale, + IS_FP8, + BLOCK_M, BLOCK_N, + N_CTX_Q, N_CTX_K_FINAL, + USE_ALIBI, alibi_slope, + USE_SLIDING_WINDOW, IS_CAUSAL, + WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, + True, + ) + else: + # Non-paged attention: process KV from cache + # Note: Cache should be updated externally before calling this kernel + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + kT_ptrs = k_offset + offs_d[:, None] * stride_kd + (start_n + offs_n)[None, :] * stride_kn + V_ptrs = v_offset + (start_n + offs_n)[:, None] * stride_vn + offs_d[None, :] * stride_vd + + # load k + kT = tl.load(kT_ptrs, mask=kT_mask, other=0.0) + v = tl.load(V_ptrs, mask=v_mask, other=0.0) + + # Use the same inner loop logic + # Precompute column validity mask for this tile (all True for full tiles). + # hi is the upper bound of the overall split range; start_n marks this tile's base. + col_valid_mask = offs_n < (hi - start_n) + + m_i, l_i, acc = _attn_fwd_inner( + q, kT, v, start_n, col_valid_mask, + m_i, l_i, acc, + pid_m, + q_descale, k_descale, v_descale, + IS_FP8, + BLOCK_M, BLOCK_N, + N_CTX_Q, N_CTX_K_FINAL, + USE_ALIBI, alibi_slope, + USE_SLIDING_WINDOW, IS_CAUSAL, + WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, + BOUNDS_CHECKS_N, + ) # write back O osk_offset = Out_splitK + pid_zhg * stride_osk_zhg + pid_splitk * stride_osk_s @@ -598,7 +704,71 @@ def attention_decode_forward_triton_impl( layout: Literal["bshd"], cache_seqlens: Optional[torch.Tensor], cache_batch_idx: Optional[torch.Tensor], + block_table: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, ): + # handle cache updates + if k_new is not None and v_new is not None: + # Update cache with new KV values + if block_table is None: + # Non-paged attention: update cache directly + batch_size = k_new.shape[0] + seqlen_new = k_new.shape[1] + + if cache_seqlens is not None: + # Use cache_seqlens to determine where to insert new KV + for b in range(batch_size): + start_idx = int(cache_seqlens[b].item()) + end_idx = start_idx + seqlen_new + k_cache[b, start_idx:end_idx] = k_new[b] + v_cache[b, start_idx:end_idx] = v_new[b] + cache_seqlens[b] = end_idx + else: + # Append at the end of existing cache + seqlen_cache = k_cache.shape[1] + k_cache[:, seqlen_cache - seqlen_new:] = k_new + v_cache[:, seqlen_cache - seqlen_new:] = v_new + else: + # Paged attention: update cache using block table + batch_size = k_new.shape[0] + seqlen_new = k_new.shape[1] + block_size = k_cache.shape[1] # k_cache shape: [num_blocks, block_size, nheads, head_dim] + + # Update cache for each batch element + for b in range(batch_size): + if cache_seqlens is not None: + start_idx = int(cache_seqlens[b].item()) + else: + # If no cache_seqlens, assume we're appending at the end + # Find the last used position from block table + start_idx = 0 + for block_idx in range(block_table.shape[1]): + if block_table[b, block_idx] >= 0: + start_idx = (block_idx + 1) * block_size + else: + start_idx = block_idx * block_size + break + + # Copy new KV values into the paged cache + for i in range(seqlen_new): + pos = start_idx + i + block_idx = pos // block_size + within_block_idx = pos % block_size + + # Get the physical block number from block table + if block_idx < block_table.shape[1]: + physical_block = int(block_table[b, block_idx].item()) + + # Update k_cache and v_cache at the physical block location + k_cache[physical_block, within_block_idx] = k_new[b, i] + v_cache[physical_block, within_block_idx] = v_new[b, i] + + # Update cache_seqlens if provided + if cache_seqlens is not None: + cache_seqlens[b] = start_idx + seqlen_new + # triton configs BLOCK_M = 16 BLOCK_N = 64 @@ -607,17 +777,45 @@ def attention_decode_forward_triton_impl( num_warps_reduce = 4 # kernel_configs - is_new_kv = True if k_new is not None and v_new is not None else False + is_new_kv = False # Cache has been updated, so no new KV in kernel use_alibi, (stride_az, stride_ah) = True if alibi_slopes is not None else False, alibi_slopes.stride() if alibi_slopes is not None else (None, None) use_cache_seqlens = cache_seqlens is not None use_sliding_window = window_size_left != -1 or window_size_right != -1 + use_block_table = block_table is not None SPLIT_K = None NUM_QUANT_GROUPS = 1 # get shapes and strides (batch_size, seqlen_q, nheads_q, dim_q), (stride_qz, stride_qh, stride_qm, stride_qd) = get_shape_and_strides_from_layout(q, layout) - (_, seqlen_kc, nheads_kc, dim_kc), (stride_kc_z, stride_kc_h, stride_kc_n, stride_kc_d) = get_shape_and_strides_from_layout(k_cache, layout) - (_, seqlen_vc, nheads_vc, dim_vc), (stride_vc_z, stride_vc_h, stride_vc_n, stride_vc_d) = get_shape_and_strides_from_layout(v_cache, layout) + + # Handle paged KV cache layout + if use_block_table: + # For paged attention, k_cache and v_cache have shape [num_blocks, block_size, nheads, head_dim] + num_blocks_kc, block_size_k, nheads_kc, dim_kc = k_cache.shape + num_blocks_vc, block_size_v, nheads_vc, dim_vc = v_cache.shape + # Get the actual sequence length from cache_seqlens or block_table + if cache_seqlens is not None: + seqlen_kc = int(cache_seqlens.max().item()) + else: + # Infer from block_table shape [batch_size, num_blocks_per_seq] + num_blocks_per_seq = block_table.shape[1] + seqlen_kc = num_blocks_per_seq * block_size_k + seqlen_vc = seqlen_kc + + # Strides for paged layout + stride_kc_z = 0 # No batch dimension in paged cache + stride_kc_n = k_cache.stride(1) # Sequence stride + stride_kc_h = k_cache.stride(2) # Head stride + stride_kc_d = k_cache.stride(3) # Dim stride + + stride_vc_z = 0 + stride_vc_n = v_cache.stride(1) + stride_vc_h = v_cache.stride(2) + stride_vc_d = v_cache.stride(3) + else: + (_, seqlen_kc, nheads_kc, dim_kc), (stride_kc_z, stride_kc_h, stride_kc_n, stride_kc_d) = get_shape_and_strides_from_layout(k_cache, layout) + (_, seqlen_vc, nheads_vc, dim_vc), (stride_vc_z, stride_vc_h, stride_vc_n, stride_vc_d) = get_shape_and_strides_from_layout(v_cache, layout) + block_size_k = 0 # Not used if is_new_kv: ( _, seqlen_kn, nheads_kn, dim_kn), (stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d) = get_shape_and_strides_from_layout(k_new, layout) (_, seqlen_vn, nheads_vn, dim_vn), (stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d) = get_shape_and_strides_from_layout(v_new, layout) @@ -657,7 +855,12 @@ def attention_decode_forward_triton_impl( split_k = SPLIT_K else: # Use heuristics - split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_kc) # NOTE: should the split think about seqlens? + if use_block_table: + # For paged attention, use the actual sequence length from cache_seqlens + max_seqlen = int(cache_seqlens.max().item()) if cache_seqlens is not None else block_size_k + split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, max_seqlen) + else: + split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_kc) split_size = (seqlen_kc + split_k - 1) // split_k # setup grid @@ -673,6 +876,39 @@ def attention_decode_forward_triton_impl( stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d = out_splitk.stride() stride_mzhg, stride_m2, stride_ms, stride_mm = metadata.stride() stride_lse_zhg, stride_lse_m = lse.stride() + + # Block table strides + if use_block_table: + stride_bt_b, stride_bt_s = block_table.stride() + else: + stride_bt_b, stride_bt_s = 0, 0 + + # FP8 support + IS_FP8 = is_fp8(q) + if IS_FP8: + if (q_descale is None) or (k_descale is None) or (v_descale is None): + import warnings + warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) + # Create default descale tensors if not provided + if q_descale is None: + q_descale = torch.ones(batch_size, nheads_q, dtype=torch.float32, device=q.device) + if k_descale is None: + k_descale = torch.ones(batch_size, nheads_kc, dtype=torch.float32, device=q.device) + if v_descale is None: + v_descale = torch.ones(batch_size, nheads_vc, dtype=torch.float32, device=q.device) + stride_q_descale_z, stride_q_descale_h = q_descale.stride() + stride_k_descale_z, stride_k_descale_h = k_descale.stride() + stride_v_descale_z, stride_v_descale_h = v_descale.stride() + else: + q_descale = None + k_descale = None + v_descale = None + stride_q_descale_z = 0 + stride_q_descale_h = 0 + stride_k_descale_z = 0 + stride_k_descale_h = 0 + stride_v_descale_z = 0 + stride_v_descale_h = 0 if DEBUG: print("batch_size, seqlen_q, nheads_q, dim_q", (batch_size, seqlen_q, nheads_q, dim_q)) @@ -689,18 +925,21 @@ def attention_decode_forward_triton_impl( print("stride_mzhg, stride_m2, stride_ms, stride_mm", (stride_mzhg, stride_m2, stride_ms, stride_mm)) print("stride_lse_zhg, stride_lse_m", (stride_lse_zhg, stride_lse_m)) - # TODO: enable quantization _fwd_kernel_splitK[grid]( Q=q, K=k_cache, V=v_cache, + Q_Descale=q_descale, + K_Descale=k_descale, + V_Descale=v_descale, sm_scale=sm_scale, Out_splitK=out_splitk, Metadata=metadata, - K_new=k_new, - V_new=v_new, + K_new=None, + V_new=None, Cache_seqlens=cache_seqlens, Cache_batch_idx=cache_batch_idx, + Block_table=block_table, Alibi_slopes=alibi_slopes, # q strides stride_qz=stride_qz, @@ -742,17 +981,28 @@ def attention_decode_forward_triton_impl( stride_vn_g=stride_vn_g, stride_vn_h=stride_vn_h, stride_vn_d=stride_vn_d, + # block table strides + stride_bt_b=stride_bt_b, + stride_bt_s=stride_bt_s, # alibi strides stride_az=stride_az, stride_ah=stride_ah, + # FP8 descale strides + stride_q_descale_z=stride_q_descale_z, + stride_q_descale_h=stride_q_descale_h, + stride_k_descale_z=stride_k_descale_z, + stride_k_descale_h=stride_k_descale_h, + stride_v_descale_z=stride_v_descale_z, + stride_v_descale_h=stride_v_descale_h, Z=batch_size, H_q=heads_per_group_q, H_kv=heads_per_group_k, G_q=n_group_q, N_CTX_Q=seqlen_q, N_CTX_K=seqlen_kc, - N_CTX_NEW=seqlen_kn, + N_CTX_NEW=0, # No new KV, cache already updated BLOCK_N_PER_SPLIT=split_size, + BLOCK_SIZE_K=block_size_k if use_block_table else 256, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=dim_padded, @@ -760,7 +1010,7 @@ def attention_decode_forward_triton_impl( BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens, USE_CACHE_SEQLENs=use_cache_seqlens, USE_CACHE_BATCH_IDX=cache_batch_idx is not None, - NEW_KV=is_new_kv, + NEW_KV=False, # Cache already updated IS_GQA=is_gqa, IS_CAUSAL=causal, USE_ALIBI=use_alibi, @@ -769,6 +1019,8 @@ def attention_decode_forward_triton_impl( USE_SLIDING_WINDOW=use_sliding_window, WINDOW_SIZE_LEFT=window_size_left, WINDOW_SIZE_RIGHT=window_size_right, + USE_BLOCK_TABLE=use_block_table, + IS_FP8=IS_FP8, num_warps=num_warps_fwd, num_stages=num_stages, ) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py old mode 100644 new mode 100755 index 3a2bd56fda4..bb0301c7700 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -3,14 +3,99 @@ import triton import triton.language as tl from typing import Literal, Optional, Union -from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_arch, is_cdna, is_fp8, is_rdna, create_dropout_mask, get_fwd_prefill_autotune_configs - -DEBUG = False +from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_arch, is_cdna, is_fp8, is_rdna, create_dropout_mask, DEBUG # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) +# ------------------------------- +# Autotune +# ------------------------------- +def get_fwd_prefill_cdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + # Fall-back config. + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL_QK', 'ACTUAL_BLOCK_DMODEL_V', 'IS_VARLEN', 'HQ', 'HK'] + + +def get_fwd_prefill_rdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + # Fall-back config. + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL_QK', 'ACTUAL_BLOCK_DMODEL_V', 'IS_VARLEN', 'HQ', 'HK'] + + +def get_fwd_prefill_autotune_configs(): + if AUTOTUNE: + if is_rdna(): + return get_fwd_prefill_rdna_autotune_configs() + elif is_cdna(): + return get_fwd_prefill_cdna_autotune_configs() + else: + raise ValueError("Unknown Device Type") + else: + arch = get_arch() + if arch == "gfx950": + default_config = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + elif arch == "gfx942" and False: # Disabled due shared mem oom in CI when using triton==3.3.0 when using top of tree everything seems fine. + default_config = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + else: + default_config = triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + + return [ + default_config + ], [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL_QK", + "ACTUAL_BLOCK_DMODEL_V", + "IS_VARLEN", + "HQ", + "HK", + ] + + fwd_prefill_autotune_configs, fwd_prefill_autotune_keys = get_fwd_prefill_autotune_configs() @triton.jit @@ -20,13 +105,14 @@ def _attn_fwd_no_mask(acc, l_i, m_i, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_base_ptrs, sd_mask_base_ptrs, dropout_mask_base_ptrs, - offs_m, offs_n, offs_d, + offs_m, offs_n, offs_d_qk, offs_d_v, block_min, block_max, alibi_slope, - descale_q, descale_k, descale_v, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + q_descale, k_descale, v_descale, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_P_DESCALE: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL_QK: tl.constexpr, BLOCK_DMODEL_V: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD_QK: tl.constexpr, PADDED_HEAD_V: tl.constexpr, + ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, ACTUAL_BLOCK_DMODEL_V: tl.constexpr, + SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, RETURN_SCORES: tl.constexpr, ACCUMULATOR_TYPE): if USE_EXP2: RCP_LN2: tl.constexpr = 1.4426950408889634 @@ -38,21 +124,27 @@ def _attn_fwd_no_mask(acc, l_i, m_i, v_ptrs = v_base_ptrs + start_n * stride_vk kv_offs_n = start_n + tl.arange(0, BLOCK_N) - if PADDED_HEAD: - k_mask, k_mask_other = (offs_d[:, None] < ACTUAL_BLOCK_DMODEL), 0.0 - v_mask, v_mask_other = (offs_d[None, :] < ACTUAL_BLOCK_DMODEL), 0.0 + if PADDED_HEAD_QK: + k_mask, k_mask_other = (offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK), 0.0 + else: + k_mask, k_mask_other = None, None + + if PADDED_HEAD_V: + v_mask, v_mask_other = (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V), 0.0 + else: + v_mask, v_mask_other = None, None # load k and if preload_v then v - k = tl.load(k_ptrs, mask=k_mask, other=k_mask_other) if PADDED_HEAD else tl.load(k_ptrs) + k = tl.load(k_ptrs, mask=k_mask, other=k_mask_other) if PADDED_HEAD_QK else tl.load(k_ptrs) if PRE_LOAD_V: - v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD else tl.load(v_ptrs) + v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD_V else tl.load(v_ptrs) # setup qk accumlator qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) # -- compute qk ---- if IS_FP8 : - qk += (tl.dot(q, k) * descale_q * descale_k) + qk += (tl.dot(q, k) * q_descale * k_descale) else: qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE @@ -124,15 +216,18 @@ def _attn_fwd_no_mask(acc, l_i, m_i, alpha = tl.math.exp(m_diff) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD else tl.load(v_ptrs) + v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD_V else tl.load(v_ptrs) # -- update m_i and l_i l_i = l_i * alpha + l_ij m_i = m_ij if IS_FP8: - scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) - acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) + if FP8_P_DESCALE: + scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) + acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * v_descale) + else: + acc += tl.dot(p.to(v.type.element_ty), v) * v_descale else: acc += tl.dot(p.to(v.type.element_ty), v) @@ -145,13 +240,14 @@ def _attn_fwd_mask(acc, l_i, m_i, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_base_ptrs, sd_mask_base_ptrs, dropout_mask_base_ptrs, - offs_m, offs_n, offs_d, + offs_m, offs_n, offs_d_qk, offs_d_v, block_min, block_max, n_extra_tokens, alibi_slope, - descale_q, descale_k, descale_v, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, - IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + q_descale, k_descale, v_descale, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_P_DESCALE: tl.constexpr, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL_QK: tl.constexpr, BLOCK_DMODEL_V: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD_QK: tl.constexpr, PADDED_HEAD_V: tl.constexpr, + ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, ACTUAL_BLOCK_DMODEL_V: tl.constexpr, + SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, ACCUMULATOR_TYPE): @@ -172,9 +268,10 @@ def _attn_fwd_mask(acc, l_i, m_i, kv_offs_n = start_n + tl.arange(0, BLOCK_N) k_mask = (kv_offs_n[None, :] < seqlen_k) v_mask = (kv_offs_n[:, None] < seqlen_k) - if PADDED_HEAD: - k_mask = k_mask & (offs_d[:, None] < ACTUAL_BLOCK_DMODEL) - v_mask = v_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + if PADDED_HEAD_QK: + k_mask = k_mask & (offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK) + if PADDED_HEAD_V: + v_mask = v_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) # load k and if preload_v then v k = tl.load(k_ptrs, mask=k_mask, other = 0.0) @@ -200,7 +297,7 @@ def _attn_fwd_mask(acc, l_i, m_i, # -- compute qk ---- if IS_FP8 : - qk += (tl.dot(q, k) * descale_q * descale_k) + qk += (tl.dot(q, k) * q_descale * k_descale) else: qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE @@ -373,8 +470,11 @@ def _attn_fwd_mask(acc, l_i, m_i, m_i = m_ij if IS_FP8: - scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) - acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) + if FP8_P_DESCALE: + scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) + acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * v_descale) + else: + acc += tl.dot(p.to(v.type.element_ty), v) * v_descale else: acc += tl.dot(p.to(v.type.element_ty), v) @@ -626,18 +726,19 @@ def compute_block_masking(seqlen_k, seqlen_q, start_m, ) @triton.jit def attn_fwd(Q, K, V, bias, - Descale_Q, Descale_K, Descale_V, Descale_O, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, + Q_Descale, K_Descale, V_Descale, stride_q_descale_z, stride_k_descale_z, stride_v_descale_z, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Add seqused parameters dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr, - HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + HK: tl.constexpr, ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, ACTUAL_BLOCK_DMODEL_V: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, BLOCK_DMODEL_V: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, NEEDS_SDMASK : tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr): + IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_P_DESCALE: tl.constexpr, USE_SEQUSED: tl.constexpr): # set params ACCUMULATOR_TYPE = tl.float32 @@ -652,24 +753,38 @@ def attn_fwd(Q, K, V, bias, else: off_h_k = off_h_q # Determine if we need to mask the heads - PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + PADDED_HEAD_QK: tl.constexpr = (ACTUAL_BLOCK_DMODEL_QK != BLOCK_DMODEL_QK) + PADDED_HEAD_V: tl.constexpr = (ACTUAL_BLOCK_DMODEL_V != BLOCK_DMODEL_V) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + offs_d_qk = tl.arange(0, BLOCK_DMODEL_QK) + offs_d_v = tl.arange(0, BLOCK_DMODEL_V) # handle seqlen if IS_VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) - seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = tl.load(seqused_q + off_z) if seqused_q is not None else cu_seqlens_q_end - cu_seqlens_q_start + seqlen_q = tl.minimum(actual_seqlen_q, cu_seqlens_q_end - cu_seqlens_q_start) + else: + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start # we have a one-size-fits-all grid in id(0). Some seqlens might be too small for all start_m so for those we return early. if start_m * BLOCK_M > seqlen_q: return cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) - seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + + # If seqused is provided, use it to limit the actual sequence length for keys + if USE_SEQUSED: + actual_seqlen_k = tl.load(seqused_k + off_z) if seqused_k is not None else cu_seqlens_k_end - cu_seqlens_k_start + seqlen_k = tl.minimum(actual_seqlen_k, cu_seqlens_k_end - cu_seqlens_k_start) + else: + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start else: cu_seqlens_q_start = 0 cu_seqlens_k_start = 0 @@ -678,11 +793,16 @@ def attn_fwd(Q, K, V, bias, # Load scale factors if IS_FP8. if IS_FP8: - descale_q = tl.load(Descale_Q + off_z * stride_descale_q_z + off_h_q) - descale_k = tl.load(Descale_K + off_z * stride_descale_k_z + off_h_k) - descale_v = tl.load(Descale_V + off_z * stride_descale_v_z + off_h_k) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (off_h_k) + # For MHA (GROUP_SIZE == 1), q_descale uses off_h_q (same as off_h_k) + if GROUP_SIZE != 1: + q_descale = tl.load(Q_Descale + off_z * stride_q_descale_z + off_h_k) # MQA/GQA: broadcast using k/v head index + else: + q_descale = tl.load(Q_Descale + off_z * stride_q_descale_z + off_h_q) # MHA: use q head index + k_descale = tl.load(K_Descale + off_z * stride_k_descale_z + off_h_k) + v_descale = tl.load(V_Descale + off_z * stride_v_descale_z + off_h_k) else: - descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + q_descale, k_descale, v_descale = 1.0, 1.0, 1.0 # figure out masking pattern @@ -701,11 +821,11 @@ def attn_fwd(Q, K, V, bias, """ # Write zeros to output o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om - o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d_v[None, :] * stride_on o_mask = (offs_m[:, None] < seqlen_q) - if PADDED_HEAD: - o_mask = o_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) - tl.store(o_ptrs, tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty), mask=o_mask) + if PADDED_HEAD_V: + o_mask = o_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) + tl.store(o_ptrs, tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=Out.type.element_ty), mask=o_mask) # Write zeros to LSE l_ptrs = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m + offs_m * stride_lse_m @@ -723,11 +843,11 @@ def attn_fwd(Q, K, V, bias, # Initialize for processing # Compute pointers for all the tensors used in this kernel. q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm - q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qk k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn - k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn + k_ptrs = k_offset + offs_d_qk[:, None] * stride_kk + offs_n[None, :] * stride_kn v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk - v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d_v[None, :] * stride_vn if USE_BIAS: # Note: this might get large enough to overflow on some configs bias_offset = off_h_q * stride_bh @@ -759,12 +879,12 @@ def attn_fwd(Q, K, V, bias, # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=ACCUMULATOR_TYPE) l_i = tl.full([BLOCK_M], 1.0, dtype=ACCUMULATOR_TYPE) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=ACCUMULATOR_TYPE) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=ACCUMULATOR_TYPE) # Q is loaded once at the beginning and shared by all N blocks. q_ptrs_mask = offs_m[:, None] < seqlen_q - if PADDED_HEAD: - q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + if PADDED_HEAD_QK: + q_ptrs_mask = q_ptrs_mask & (offs_d_qk[None, :] < ACTUAL_BLOCK_DMODEL_QK) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) @@ -781,17 +901,17 @@ def attn_fwd(Q, K, V, bias, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, - offs_m, offs_n, offs_d, + offs_m, offs_n, offs_d_qk, offs_d_v, block_min, # Start of front masked blocks block_max, # End of front masked blocks 0, # n_extra_tokens (0 for front blocks, only relevant for last block) alibi_slope, - descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, + q_descale, k_descale, v_descale, IS_FP8, FP8_MAX, FP8_P_DESCALE, IS_CAUSAL, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, + BLOCK_M, BLOCK_DMODEL_QK, BLOCK_DMODEL_V, BLOCK_N, PRE_LOAD_V, - ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, + ENABLE_DROPOUT, PADDED_HEAD_QK, PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK, ACTUAL_BLOCK_DMODEL_V, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, @@ -812,15 +932,15 @@ def attn_fwd(Q, K, V, bias, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, - offs_m, offs_n, offs_d, + offs_m, offs_n, offs_d_qk, offs_d_v, block_min, # Start of range: 0 block_max, # End of range: n_full_blocks * BLOCK_N alibi_slope, - descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, + q_descale, k_descale, v_descale, IS_FP8, FP8_MAX, FP8_P_DESCALE, + BLOCK_M, BLOCK_DMODEL_QK, BLOCK_DMODEL_V, BLOCK_N, PRE_LOAD_V, - ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, + ENABLE_DROPOUT, PADDED_HEAD_QK, PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK, ACTUAL_BLOCK_DMODEL_V, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE ) @@ -837,17 +957,17 @@ def attn_fwd(Q, K, V, bias, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, - offs_m, offs_n, offs_d, + offs_m, offs_n, offs_d_qk, offs_d_v, block_min, # Start of range: n_full_blocks * BLOCK_N block_max, # End of range: n_visible_k_blocks * BLOCK_N n_extra_tokens, # Padding tokens in last block alibi_slope, - descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, + q_descale, k_descale, v_descale, IS_FP8, FP8_MAX, FP8_P_DESCALE, IS_CAUSAL, # Use actual causal flag - BLOCK_M, BLOCK_DMODEL, BLOCK_N, + BLOCK_M, BLOCK_DMODEL_QK, BLOCK_DMODEL_V, BLOCK_N, PRE_LOAD_V, - ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, + ENABLE_DROPOUT, PADDED_HEAD_QK, PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK, ACTUAL_BLOCK_DMODEL_V, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, @@ -933,7 +1053,7 @@ def attn_fwd(Q, K, V, bias, end_m_idx = (start_m + 1) * BLOCK_M if causal_start_idx < end_m_idx: # This block contains the boundary - need to mask acc - out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + out_mask_boundary = tl.full((BLOCK_DMODEL_V, ), causal_start_idx, dtype=tl.int32) out_ptrs_mask = row_indices[:, None] >= out_mask_boundary[None, :] z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) @@ -958,19 +1078,14 @@ def attn_fwd(Q, K, V, bias, # write back O o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om - o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on - o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d_v[None, :] * stride_on + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL_V], 1, dtype=tl.int1) if overflow_size > 0: o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) - if PADDED_HEAD: - o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + if PADDED_HEAD_V: + o_ptrs_mask = o_ptrs_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) - if FP8_OUTPUT: - scale_acc, descale_acc = compute_fp8_scaling_factors(acc, FP8_MAX) - tl.store(Descale_O + off_z * stride_descale_o_z + off_h_q, descale_acc) - tl.store(o_ptrs, (acc * scale_acc).to(Out.type.element_ty), mask=o_ptrs_mask) - else: - tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) def attention_prefill_forward_triton_impl( q: torch.Tensor, @@ -997,10 +1112,12 @@ def attention_prefill_forward_triton_impl( return_softmax: bool, use_exp2: bool, # fp8 - descale_q: Optional[torch.Tensor], - descale_k: Optional[torch.Tensor], - descale_v: Optional[torch.Tensor], - descale_o: Optional[torch.Tensor], + q_descale: Optional[torch.Tensor], + k_descale: Optional[torch.Tensor], + v_descale: Optional[torch.Tensor], + # seqused for FA v3 + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, ): # get params, strides and shape IS_VARLEN = layout == "thd" @@ -1027,12 +1144,12 @@ def attention_prefill_forward_triton_impl( assert max_seqlens_k is not None and max_seqlens_k > 0, "max_seqlens_k must be provided and positive for varlen layout" # assert head dimensions - assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" # assert output shapes - assert o.shape == (total_seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_q)}" + assert o.shape == (total_seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_v)}" # assert cu_seqlens assert cu_seqlens_q.dtype == torch.int32, f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}" @@ -1044,7 +1161,7 @@ def attention_prefill_forward_triton_impl( # set vars batch = len(cu_seqlens_q) - 1 - head_size = head_size_q + head_size_qk = head_size_q # softmax_lse shape softmax_lse = torch.zeros((nheads_q, total_seqlen_q), device=q.device, dtype=torch.float32) @@ -1065,7 +1182,7 @@ def attention_prefill_forward_triton_impl( assert batch_q == batch_k == batch_v, f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" # assert head dimensions - assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" @@ -1073,11 +1190,11 @@ def attention_prefill_forward_triton_impl( assert seqlen_k == seqlen_v, f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" # assert output shapes - assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(batch_q, seqlen_q, nheads_q, head_size_q)}" + assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected {(batch_q, seqlen_q, nheads_q, head_size_v)}" # set vars batch = batch_q - head_size = head_size_q + head_size_qk = head_size_q max_seqlens_q = seqlen_q max_seqlens_k = seqlen_k @@ -1098,29 +1215,32 @@ def attention_prefill_forward_triton_impl( FP8_MAX = torch.finfo(q.dtype).max - # Check descale tensors - assert descale_q is not None, "descale_q must be provided when using fp8" - assert descale_k is not None, "descale_k must be provided when using fp8" - assert descale_v is not None, "descale_v must be provided when using fp8" + # Check and create default descale tensors if not provided + if (q_descale is None) or (k_descale is None) or (v_descale is None): + import warnings + warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) + # Create default descale tensors if not provided + if q_descale is None: + q_descale = torch.ones(batch, nheads_q, dtype=torch.float32, device=q.device) + if k_descale is None: + k_descale = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) + if v_descale is None: + v_descale = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) - if is_fp8(o): - FP8_OUTPUT = True - assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o." - else: - FP8_OUTPUT = False - # o should be fp32 or fp16/bf16 - assert o.dtype in [torch.float16, torch.bfloat16, torch.float32], \ - f"Output tensor o must be fp16, bf16, or fp32 when using fp8, got {o.dtype}" - - stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None - stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None - stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None - stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None + # o should be fp32 or fp16/bf16 + assert o.dtype in [torch.float16, torch.bfloat16, torch.float32], \ + f"Output tensor o must be fp16, bf16, or fp32 when using fp8, got {o.dtype}" + + stride_q_descale_z = q_descale.stride(0) if q_descale is not None else 0 + stride_k_descale_z = k_descale.stride(0) if k_descale is not None else 0 + stride_v_descale_z = v_descale.stride(0) if v_descale is not None else 0 + + if DEBUG: + print(f"FP8 path triggered in fwd_prefill.py") else: FP8_MAX = None - FP8_OUTPUT = False - descale_q = descale_k = descale_v = descale_o = None - stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = None + q_descale = k_descale = v_descale = None + stride_q_descale_z = stride_k_descale_z = stride_v_descale_z = None # check output dtype matches input dtype when not using fp8 assert o.dtype == q.dtype, f"Output dtype {o.dtype} must match input dtype {q.dtype} when not using fp8" @@ -1132,11 +1252,13 @@ def attention_prefill_forward_triton_impl( if (bias is not None): assert (bias.numel() < 2**31) - # Get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() + # Get closest power of 2 over or equal to 32 for both QK and V dimensions + padded_d_model_qk = 1 << (head_size_qk - 1).bit_length() + padded_d_model_v = 1 << (head_size_v - 1).bit_length() # Smallest head_dim supported is 16. If smaller, the tile in the # kernel is padded - there is no padding in memory for any dims. - padded_d_model = max(padded_d_model, 16) + padded_d_model_qk = max(padded_d_model_qk, 16) + padded_d_model_v = max(padded_d_model_v, 16) # sd_mask is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according @@ -1166,7 +1288,7 @@ def attention_prefill_forward_triton_impl( # launch kernel grid = lambda META: (batch, nheads_q, triton.cdiv(max_seqlens_q, META['BLOCK_M'])) attn_fwd[grid](q, k, v, bias, - descale_q, descale_k, descale_v, descale_o, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, + q_descale, k_descale, v_descale, stride_q_descale_z, stride_k_descale_z, stride_v_descale_z, sm_scale, softmax_lse, o, stride_qb, stride_qh, stride_qm, stride_qd, stride_kb, stride_kh, stride_kn, stride_kd, @@ -1177,13 +1299,15 @@ def attention_prefill_forward_triton_impl( stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Pass seqused tensors dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, - HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, + HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL_QK=head_size_qk, ACTUAL_BLOCK_DMODEL_V=head_size_v, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, USE_SLIDING_WINDOW=use_sliding_window, WINDOW_SIZE_LEFT=window_size_left, WINDOW_SIZE_RIGHT=window_size_right, IS_VARLEN=IS_VARLEN, - BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, + BLOCK_DMODEL_QK=padded_d_model_qk, BLOCK_DMODEL_V=padded_d_model_v, USE_BIAS=False if bias is None else True, USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, NEEDS_SDMASK=NEEDS_SDMASK, - IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT) + IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_P_DESCALE=False, + USE_SEQUSED=(seqused_q is not None or seqused_k is not None)) # Add flag for seqused return softmax_lse, sd_mask if return_softmax else None \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/fwd_ref.py b/flash_attn/flash_attn_triton_amd/fwd_ref.py old mode 100644 new mode 100755 index 2265af12096..a8ca54a7ec3 --- a/flash_attn/flash_attn_triton_amd/fwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/fwd_ref.py @@ -9,7 +9,7 @@ def attention_forward_core_ref_impl( q, k, v, sm_scale, causal, window_size_left, window_size_right, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2, - cache_seqlens=None + cache_seqlens=None, block_table=None, paged_kv_block_size=None ): if DEBUG_CORE: print() @@ -26,17 +26,99 @@ def attention_forward_core_ref_impl( print("philox_offset:", philox_offset) print("use_exp2:", use_exp2) print("cache_seqlens:", cache_seqlens) + print("block_table:", block_table) + print("paged_kv_block_size:", paged_kv_block_size) # cast to float32 q = q.to(torch.float32) - k = k.to(torch.float32) - v = v.to(torch.float32) - - # get seqlens - L_q, L_k = q.shape[1], k.shape[1] - # Compute attention scores - attention_scores = torch.matmul(q, k.transpose(-2, -1)) + # Check if we're in paged KV mode + is_paged = block_table is not None and paged_kv_block_size is not None + + if False: # Debug paged attention (disabled for production) + print(f"\n=== attention_forward_core_ref_impl DEBUG ===") + print(f"is_paged: {is_paged}") + print(f"block_table: {block_table.shape if block_table is not None else None}") + print(f"paged_kv_block_size: {paged_kv_block_size}") + if is_paged: + print(f"k shape (paged): {k.shape}") + print(f"v shape (paged): {v.shape}") + print(f"cache_seqlens: {cache_seqlens}") + + if is_paged: + # In paged mode, k and v are [num_blocks, block_size, nheads_k, head_dim] + # We'll compute attention on-the-fly without reconstructing + nheads_q = q.shape[0] + L_q = q.shape[1] + head_dim = q.shape[2] + + # Get number of KV heads from the cache + nheads_k = k.shape[2] # k shape: [num_blocks, block_size, nheads_k, head_dim] + + # Handle MQA/GQA + assert nheads_q % nheads_k == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_k ({nheads_k})" + group_size = nheads_q // nheads_k + + # Determine the actual KV sequence length from cache_seqlens + L_k = cache_seqlens if isinstance(cache_seqlens, int) else cache_seqlens.item() + + if False: # Debug disabled + print(f"L_q: {L_q}, L_k: {L_k}, nheads_q: {nheads_q}, nheads_k: {nheads_k}, group_size: {group_size}, head_dim: {head_dim}") + print(f"block_table contents: {block_table if block_table is not None else 'None'}") + + # Initialize attention scores + attention_scores = torch.zeros((nheads_q, L_q, L_k), dtype=torch.float32, device=q.device) + + # Compute attention scores on-the-fly by accessing blocks directly + for kv_pos in range(L_k): + # Calculate which block and position within block + block_idx = kv_pos // paged_kv_block_size + within_block_idx = kv_pos % paged_kv_block_size + + # Get the physical block number from block_table + # block_table is [1, num_blocks] for single batch in core function + if block_table.dim() == 2: + physical_block = block_table[0, block_idx].item() + else: + physical_block = block_table[block_idx].item() + + # Debug output disabled + # if kv_pos == 0: + # print(f"First KV access: block_idx={block_idx}, within_block={within_block_idx}, physical_block={physical_block}") + # print(f"k_vec shape will be: {k[physical_block, within_block_idx, :, :].shape}") + + # Access k values directly from paged cache + # k shape: [num_blocks, block_size, nheads_k, head_dim] + k_vec = k[physical_block, within_block_idx, :, :].to(torch.float32) # [nheads_k, head_dim] + + # For GQA/MQA, we need to repeat k_vec for each group + if group_size > 1: + # Expand k_vec to match query heads + # k_vec: [nheads_k, head_dim] -> [nheads_q, head_dim] + k_vec = k_vec.repeat_interleave(group_size, dim=0) + + # Compute dot product with all query positions + # q is [nheads_q, L_q, head_dim], k_vec is [nheads_q, head_dim] + # Result should be [nheads_q, L_q] for this kv_pos + attention_scores[:, :, kv_pos] = torch.sum(q * k_vec.unsqueeze(1), dim=-1) + + # Keep k and v in original format for later v computation + k_paged = k + v_paged = v + + # Debug output disabled + # print(f"attention_scores computed shape: {attention_scores.shape}") + # print(f"attention_scores sample values: {attention_scores[0, 0, :5]}") + else: + # Standard non-paged mode + k = k.to(torch.float32) + v = v.to(torch.float32) + + # get seqlens + L_q, L_k = q.shape[1], k.shape[1] + + # Compute attention scores + attention_scores = torch.matmul(q, k.transpose(-2, -1)) if DEBUG_CORE: print("attention_scores:", attention_scores, attention_scores.shape) @@ -256,7 +338,55 @@ def attention_forward_core_ref_impl( print("softmax_lse:", softmax_lse, softmax_lse.shape) # Compute output - o = torch.matmul(p, v) + if is_paged: + # Compute output on-the-fly using paged v cache + nheads_q = p.shape[0] + L_q = p.shape[1] + nheads_v = v_paged.shape[2] # [num_blocks, block_size, nheads_v, head_dim] + head_dim = v_paged.shape[3] + + # Handle MQA/GQA for v + assert nheads_q % nheads_v == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_v ({nheads_v})" + v_group_size = nheads_q // nheads_v + + o = torch.zeros((nheads_q, L_q, head_dim), dtype=torch.float32, device=p.device) + + # Accumulate weighted v values + for kv_pos in range(L_k): + # Calculate which block and position within block + block_idx = kv_pos // paged_kv_block_size + within_block_idx = kv_pos % paged_kv_block_size + + # Get the physical block number from block_table + if block_table.dim() == 2: + physical_block = block_table[0, block_idx].item() + else: + physical_block = block_table[block_idx].item() + + # Access v values directly from paged cache + # v_paged shape: [num_blocks, block_size, nheads_v, head_dim] + v_vec = v_paged[physical_block, within_block_idx, :, :].to(torch.float32) # [nheads_v, head_dim] + + # For GQA/MQA, we need to repeat v_vec for each group + if v_group_size > 1: + # Expand v_vec to match query heads + # v_vec: [nheads_v, head_dim] -> [nheads_q, head_dim] + v_vec = v_vec.repeat_interleave(v_group_size, dim=0) + + # Weight by attention probabilities + # p is [nheads_q, L_q, L_k], we need p[:, :, kv_pos] which is [nheads_q, L_q] + # v_vec is [nheads_q, head_dim] + # We want to add p[:, :, kv_pos] * v_vec to each query position + weights = p[:, :, kv_pos].unsqueeze(-1) # [nheads_q, L_q, 1] + o += weights * v_vec.unsqueeze(1) # [nheads_q, L_q, head_dim] + else: + o = torch.matmul(p, v) + + # Debug output disabled + # if False: + # print(f"Output o shape: {o.shape}") + # print(f"Output o sample values: {o[0, 0, :5]}") + if DEBUG_CORE: print("o:", o, o.shape) @@ -439,7 +569,7 @@ def attention_varlen_forward_pytorch_ref_impl( return o, softmax_lse, sd_mask -def attention_forward_pytorch_ref_impl( +def attention_prefill_forward_ref_impl( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -516,12 +646,61 @@ def attention_decode_forward_ref_impl( layout: Literal["bshd"], cache_seqlens: Optional[torch.Tensor], cache_batch_idx: Optional[torch.Tensor], + block_table: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, ): """Compute reference output for decode attention using PyTorch's built-in functions""" + if False: # Permanently disabled old debug output + pass + # print("\n========== attention_decode_forward_ref_impl inputs ==========") + # print(f"q shape: {q.shape}, dtype: {q.dtype}, device: {q.device}") + print(f"q values:\n{q}") + print(f"\nk_cache shape: {k_cache.shape}, dtype: {k_cache.dtype}, device: {k_cache.device}") + print(f"k_cache values:\n{k_cache}") + print(f"\nv_cache shape: {v_cache.shape}, dtype: {v_cache.dtype}, device: {v_cache.device}") + print(f"v_cache values:\n{v_cache}") + print(f"\nk_new: {k_new.shape if k_new is not None else None}, dtype: {k_new.dtype if k_new is not None else None}") + if k_new is not None: + print(f"k_new values:\n{k_new}") + print(f"\nv_new: {v_new.shape if v_new is not None else None}, dtype: {v_new.dtype if v_new is not None else None}") + if v_new is not None: + print(f"v_new values:\n{v_new}") + print(f"\nout shape: {out.shape}, dtype: {out.dtype}, device: {out.device}") + print(f"out values:\n{out}") + print(f"\nsm_scale: {sm_scale}") + print(f"causal: {causal}") + print(f"window_size_left: {window_size_left}") + print(f"window_size_right: {window_size_right}") + print(f"\nalibi_slopes: {alibi_slopes.shape if alibi_slopes is not None else None}, dtype: {alibi_slopes.dtype if alibi_slopes is not None else None}") + if alibi_slopes is not None: + print(f"alibi_slopes values:\n{alibi_slopes}") + print(f"\nlayout: {layout}") + print(f"cache_seqlens: {cache_seqlens}") + if cache_seqlens is not None and torch.is_tensor(cache_seqlens): + print(f"cache_seqlens values: {cache_seqlens}") + print(f"cache_batch_idx: {cache_batch_idx}") + if cache_batch_idx is not None: + print(f"cache_batch_idx values: {cache_batch_idx}") + print(f"\nblock_table: {block_table.shape if block_table is not None else None}, dtype: {block_table.dtype if block_table is not None else None}") + if block_table is not None: + print(f"block_table values:\n{block_table}") + print("=" * 60) + # get batch size before any layout conversion batch_size = q.shape[0] + # Determine if we're in paged KV mode + is_paged = block_table is not None + if is_paged: + # Infer block size from cache shape + # k_cache shape for paged: [num_blocks, block_size, nheads, head_dim] + paged_kv_block_size = k_cache.shape[1] + else: + paged_kv_block_size = None + # handle cache_batch_idx if cache_batch_idx is not None: # remap batch indices for cache access @@ -531,39 +710,79 @@ def attention_decode_forward_ref_impl( # copy new keys and values into cache if provided (before any layout conversion) if k_new is not None and v_new is not None: - _, seq_len_new, _, _ = k_new.shape # shape is [batch, seq_len, nheads, head_dim] for bshd layout - - for b in range(batch_size): - cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices + if is_paged: + # For paged KV cache, we need to update the blocks with new k/v values + _, seq_len_new, _, _ = k_new.shape # shape is [batch, seq_len, nheads, head_dim] for bshd layout - # determine where to place new k/v in cache - if cache_seqlens is not None: - if torch.is_tensor(cache_seqlens): - start_pos = cache_seqlens[b].item() + for b in range(batch_size): + # Determine where to place new k/v in cache + if cache_seqlens is not None: + if torch.is_tensor(cache_seqlens): + start_pos = cache_seqlens[b].item() + else: + start_pos = cache_seqlens else: - start_pos = cache_seqlens - else: - # if no cache_seqlens, assume we're filling from the beginning - start_pos = 0 - - end_pos = start_pos + seq_len_new + start_pos = 0 + + # For each new position, find the corresponding block and update it + for pos_offset in range(seq_len_new): + kv_pos = start_pos + pos_offset + + # Calculate which block and position within block + block_idx = kv_pos // paged_kv_block_size + within_block_idx = kv_pos % paged_kv_block_size + + # Get the physical block number from block_table + physical_block = block_table[b, block_idx].item() + + # Update the k and v values in the paged cache + # k_cache shape: [num_blocks, block_size, nheads, head_dim] + # k_new shape: [batch, seq_len, nheads, head_dim] + k_cache[physical_block, within_block_idx, :, :] = k_new[b, pos_offset, :, :] + v_cache[physical_block, within_block_idx, :, :] = v_new[b, pos_offset, :, :] + else: + _, seq_len_new, _, _ = k_new.shape # shape is [batch, seq_len, nheads, head_dim] for bshd layout - # copy new keys and values into cache (both are in bshd layout) - k_cache[cache_idx, start_pos:end_pos, :, :] = k_new[b, :, :, :] - v_cache[cache_idx, start_pos:end_pos, :, :] = v_new[b, :, :, :] + for b in range(batch_size): + cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices + + # determine where to place new k/v in cache + if cache_seqlens is not None: + if torch.is_tensor(cache_seqlens): + start_pos = cache_seqlens[b].item() + else: + start_pos = cache_seqlens + else: + # if no cache_seqlens, assume we're filling from the beginning + start_pos = 0 + + end_pos = start_pos + seq_len_new + + # copy new keys and values into cache (both are in bshd layout) + k_cache[cache_idx, start_pos:end_pos, :, :] = k_new[b, :, :, :] + v_cache[cache_idx, start_pos:end_pos, :, :] = v_new[b, :, :, :] # ensure the layout is 'bhsd' if layout == "bshd": q = q.transpose(1, 2).contiguous() - k_cache = k_cache.transpose(1, 2).contiguous() - v_cache = v_cache.transpose(1, 2).contiguous() + if not is_paged: + k_cache = k_cache.transpose(1, 2).contiguous() + v_cache = v_cache.transpose(1, 2).contiguous() elif layout != "bhsd": raise ValueError(f"Unknown layout {layout}") # prepare tensors batch_size_q, nheads_q, seq_len_q, head_dim = q.shape - batch_size_cache, nheads_k, max_cache_len, head_dim_k = k_cache.shape - _, nheads_v, _, head_dim_v = v_cache.shape + + if is_paged: + # For paged cache: [num_blocks, block_size, nheads, head_dim] + num_blocks, block_size, nheads_k, head_dim_k = k_cache.shape + _, _, nheads_v, head_dim_v = v_cache.shape + max_cache_len = None # Not directly available in paged mode + batch_size_cache = None # Not applicable in paged mode + else: + batch_size_cache, nheads_k, max_cache_len, head_dim_k = k_cache.shape + _, nheads_v, _, head_dim_v = v_cache.shape # validate dimensions assert head_dim == head_dim_k == head_dim_v, f"Head dimensions must match: {head_dim}, {head_dim_k}, {head_dim_v}" @@ -586,7 +805,8 @@ def attention_decode_forward_ref_impl( # process each batch element for b in range(batch_size): - cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices + if not is_paged: + cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices # determine valid cache length for this batch element if cache_seqlens is not None: @@ -601,25 +821,41 @@ def attention_decode_forward_ref_impl( _, seq_len_new, _, _ = k_new.shape cache_len += seq_len_new else: - cache_len = max_cache_len - - # CHANGE: Extract the full cache, not just valid portion - # This matches what the test does - it uses full k_cache_rep/v_cache_rep - k_b = k_cache[cache_idx, :, :, :] # [nheads_k, max_cache_len, head_dim] - v_b = v_cache[cache_idx, :, :, :] # [nheads_v, max_cache_len, head_dim] - q_b = q[b:b+1, :, :, :] # [1, nheads_q, seq_len_q, head_dim] + if is_paged: + # For paged mode, we need cache_seqlens to know the valid length + raise ValueError("cache_seqlens must be provided for paged KV cache") + else: + cache_len = max_cache_len - # handle MQA/GQA by expanding k and v - if group_size != 1: - # expand k and v to match q's number of heads - k_b = k_b.unsqueeze(1).expand(-1, group_size, -1, -1) - k_b = k_b.reshape(nheads_q, max_cache_len, head_dim) + if is_paged: + # For paged KV cache, pass the cache and block table directly + # Extract block table for this batch element + block_table_b = block_table[b:b+1, :] # [1, num_blocks] + k_b = k_cache # Pass entire paged cache + v_b = v_cache # Pass entire paged cache + q_b = q[b:b+1, :, :, :] # [1, nheads_q, seq_len_q, head_dim] - v_b = v_b.unsqueeze(1).expand(-1, group_size, -1, -1) - v_b = v_b.reshape(nheads_q, max_cache_len, head_dim) - - # reshape for attention_forward_core_ref_impl - q_b = q_b.reshape(nheads_q, seq_len_q, head_dim) + # For paged mode with MQA/GQA, we handle expansion in the core function + # Just reshape q for now + q_b = q_b.reshape(nheads_q, seq_len_q, head_dim) + else: + # Standard non-paged mode + k_b = k_cache[cache_idx, :, :, :] # [nheads_k, max_cache_len, head_dim] + v_b = v_cache[cache_idx, :, :, :] # [nheads_v, max_cache_len, head_dim] + q_b = q[b:b+1, :, :, :] # [1, nheads_q, seq_len_q, head_dim] + block_table_b = None + + # handle MQA/GQA by expanding k and v + if group_size != 1: + # expand k and v to match q's number of heads + k_b = k_b.unsqueeze(1).expand(-1, group_size, -1, -1) + k_b = k_b.reshape(nheads_q, max_cache_len, head_dim) + + v_b = v_b.unsqueeze(1).expand(-1, group_size, -1, -1) + v_b = v_b.reshape(nheads_q, max_cache_len, head_dim) + + # reshape for attention_forward_core_ref_impl + q_b = q_b.reshape(nheads_q, seq_len_q, head_dim) # handle alibi slopes for this batch alibi_slopes_b = None @@ -635,6 +871,8 @@ def attention_decode_forward_ref_impl( dropout_p=0.0, philox_seed=None, philox_offset=None, alibi_slopes=alibi_slopes_b, use_exp2=True, cache_seqlens=cache_len, # Pass valid cache length + block_table=block_table_b, # Pass block table for paged mode + paged_kv_block_size=paged_kv_block_size, # Pass block size for paged mode ) # store outputs diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index a6b83f64365..3dc443abf67 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -5,7 +5,7 @@ from .bwd_prefill_fused_atomics import attention_prefill_backward_triton_fused_atomics_impl from .bwd_prefill_fused_no_atomics import attention_prefill_backward_triton_split_fused_no_atomics_impl from .fwd_decode import attention_decode_forward_triton_impl -from .fwd_ref import attention_forward_pytorch_ref_impl, attention_decode_forward_ref_impl +from .fwd_ref import attention_prefill_forward_ref_impl, attention_decode_forward_ref_impl from .bwd_ref import attention_backward_pytorch_ref_impl from .utils import DEBUG, USE_REF, MetaData, is_fp8 from einops import rearrange, repeat @@ -31,8 +31,7 @@ def fwd(q: torch.Tensor, gen_: Optional[torch.Tensor] = None, descale_q: Optional[torch.Tensor] = None, descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_o: Optional[torch.Tensor] = None + descale_v: Optional[torch.Tensor] = None ): if DEBUG: @@ -53,11 +52,12 @@ def fwd(q: torch.Tensor, print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) - print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) if is_fp8(q): assert out is not None, "fp8 output tensor should be passed in." - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"For fp8, you need to pass descale factors for q, k and v" + if (descale_q is None) or (descale_k is None) or (descale_v is None): + import warnings + warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) else: out = torch.zeros_like(q) if out is None else out.zero_() @@ -93,7 +93,7 @@ def fwd(q: torch.Tensor, if USE_REF: if DEBUG: print("Using reference implementation") - softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( + softmax_lse_ref, sd_mask_ref = attention_prefill_forward_ref_impl( q, k, v, @@ -140,16 +140,13 @@ def fwd(q: torch.Tensor, USE_EXP2, descale_q, descale_k, - descale_v, - descale_o) + descale_v) softmax_lse=softmax_lse_triton sd_mask=sd_mask_triton if DEBUG: print("flash_attn_triton_amd.py::fwd outputs") print("o:", out, out.shape) - if is_fp8(out): - print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) print("softmax_lse:", softmax_lse, softmax_lse.shape) print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None ) print("rng_state:", rng_state) @@ -405,8 +402,7 @@ def varlen_fwd( gen_: Optional[torch.Tensor] = None, descale_q: Optional[torch.Tensor] = None, descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_o: Optional[torch.Tensor] = None + descale_v: Optional[torch.Tensor] = None ): if DEBUG: @@ -432,7 +428,9 @@ def varlen_fwd( if is_fp8(q): assert out is not None, "fp8 output tensor should be passed in." - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"For fp8, you need to pass descale factors for q, k and v" + if (descale_q is None) or (descale_k is None) or (descale_v is None): + import warnings + warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) else: out = torch.zeros_like(q) if out is None else out.zero_() @@ -468,7 +466,7 @@ def varlen_fwd( if USE_REF: if DEBUG: print("Using reference implementation") - softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( + softmax_lse_ref, sd_mask_ref = attention_prefill_forward_ref_impl( q, k, v, @@ -515,8 +513,7 @@ def varlen_fwd( USE_EXP2, descale_q, descale_k, - descale_v, - descale_o) + descale_v) softmax_lse=softmax_lse_triton sd_mask=sd_mask_triton @@ -899,6 +896,7 @@ def fwd_kvcache( metadata.layout, metadata.cache_seqlens, metadata.cache_batch_idx, + block_table, ) softmax_lse=softmax_lse_ref else: @@ -919,6 +917,7 @@ def fwd_kvcache( metadata.layout, metadata.cache_seqlens, metadata.cache_batch_idx, + block_table, ) softmax_lse = softmax_lse_triton diff --git a/flash_attn/flash_attn_triton_amd/interface_fa_v3.py b/flash_attn/flash_attn_triton_amd/interface_fa_v3.py new file mode 100755 index 00000000000..be8e2d3cbeb --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/interface_fa_v3.py @@ -0,0 +1,660 @@ +import torch +import os +from .fwd_prefill import attention_prefill_forward_triton_impl +from .bwd_prefill_split import attention_prefill_backward_triton_split_impl +from .bwd_prefill_fused_atomics import attention_prefill_backward_triton_fused_atomics_impl +from .bwd_prefill_fused_no_atomics import attention_prefill_backward_triton_split_fused_no_atomics_impl +from .fwd_decode import attention_decode_forward_triton_impl +from .fwd_ref import attention_prefill_forward_ref_impl, attention_decode_forward_ref_impl +from .bwd_ref import attention_backward_pytorch_ref_impl +from .utils import DEBUG, USE_REF, MetaData, is_fp8 +from einops import rearrange, repeat +from flash_attn.layers.rotary import apply_rotary_emb +from typing import Optional, Union, Tuple + +USE_EXP2 = True +BWD_MODE = os.environ.get('BWD_MODE', 'fused_no_atomics').lower() +USE_DECODE_PATH = os.environ.get('FLASH_ATTENTION_V3_USE_DECODE', '0') == '1' + +def fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_new: Optional[torch.Tensor], + v_new: Optional[torch.Tensor], + qv: Optional[torch.Tensor], + out: Optional[torch.Tensor], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + cu_seqlens_k_new: Optional[torch.Tensor], + seqused_q: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + page_table: Optional[torch.Tensor], + kv_batch_idx: Optional[torch.Tensor], + leftpad_k: Optional[torch.Tensor], + rotary_cos: Optional[torch.Tensor], + rotary_sin: Optional[torch.Tensor], + seqlens_rotary: Optional[torch.Tensor], + q_descale: Optional[torch.Tensor], + k_descale: Optional[torch.Tensor], + v_descale: Optional[torch.Tensor], + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + attention_chunk: int, + softcap: float, + rotary_interleaved: bool, + scheduler_metadata=None, + num_splits: int = 1, + pack_gqa=None, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Flash Attention v3 forward pass compatible interface for AMD Triton implementation. + + This function maps v3 parameters to the existing AMD Triton implementation. + """ + + if DEBUG: + print() + print("interface_fa_v3.py::fwd inputs") + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("k_new:", k_new, k_new.shape if k_new is not None else None) + print("v_new:", v_new, v_new.shape if v_new is not None else None) + print("qv:", qv, qv.shape if qv is not None else None) + print("out:", out, out.shape if out is not None else None) + print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape if cu_seqlens_q is not None else None) + print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape if cu_seqlens_k is not None else None) + print("cu_seqlens_k_new:", cu_seqlens_k_new, cu_seqlens_k_new.shape if cu_seqlens_k_new is not None else None) + print("seqused_q:", seqused_q, seqused_q.shape if seqused_q is not None else None) + print("seqused_k:", seqused_k, seqused_k.shape if seqused_k is not None else None) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("page_table:", page_table, page_table.shape if page_table is not None else None) + print("kv_batch_idx:", kv_batch_idx, kv_batch_idx.shape if kv_batch_idx is not None else None) + print("leftpad_k:", leftpad_k, leftpad_k.shape if leftpad_k is not None else None) + print("rotary_cos:", rotary_cos, rotary_cos.shape if rotary_cos is not None else None) + print("rotary_sin:", rotary_sin, rotary_sin.shape if rotary_sin is not None else None) + print("seqlens_rotary:", seqlens_rotary, seqlens_rotary.shape if seqlens_rotary is not None else None) + print("q_descale:", q_descale, q_descale.shape if q_descale is not None else None) + print("k_descale:", k_descale, k_descale.shape if k_descale is not None else None) + print("v_descale:", v_descale, v_descale.shape if v_descale is not None else None) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("attention_chunk:", attention_chunk) + print("softcap:", softcap) + print("rotary_interleaved:", rotary_interleaved) + print("scheduler_metadata:", scheduler_metadata) + print("num_splits:", num_splits) + print("pack_gqa:", pack_gqa) + print("sm_margin:", sm_margin) + + # Handle qv packed input + if qv is not None: + raise NotImplementedError("QV packed input is not yet supported in the AMD Triton backend") + + + # Handle softcap + if softcap != 0.0: + raise NotImplementedError(f"Softcap is not yet supported in the AMD Triton backend (got softcap={softcap}, expected 0.0)") + + # Handle attention_chunk + if attention_chunk != 0 and attention_chunk != 1: + raise NotImplementedError(f"attention_chunk is not yet supported in the AMD Triton backend (got attention_chunk={attention_chunk})") + + + # Handle scheduler metadata + if scheduler_metadata is not None: + raise NotImplementedError("Scheduler metadata is not yet supported in the AMD Triton backend") + + # Handle pack_gqa + if pack_gqa is not None and pack_gqa is not False: + raise NotImplementedError(f"pack_gqa is not yet supported in the AMD Triton backend (got pack_gqa={pack_gqa})") + + # Handle num_splits + if num_splits != 1: + raise NotImplementedError(f"Split attention (num_splits > 1) is not yet supported in the AMD Triton backend (got num_splits={num_splits})") + + # Handle sm_margin + if sm_margin != 0: + raise NotImplementedError(f"sm_margin is not yet supported in the AMD Triton backend (got sm_margin={sm_margin}, expected 0)") + + # Handle leftpad_k + if leftpad_k is not None: + raise NotImplementedError("Left padding (leftpad_k) is not yet supported in the AMD Triton backend") + + # Handle cu_seqlens_k_new + if cu_seqlens_k_new is not None: + raise NotImplementedError("cu_seqlens_k_new is not yet supported in the AMD Triton backend") + + # if seqlens_rotary is not None: + # raise NotImplementedError("seqlens_rotary is not yet supported in the AMD Triton backend") + + # Setup metadata + metadata = MetaData(sm_scale=softmax_scale) + + + # Handle variable length sequences first to determine layout + # Determine layout based on tensor dimensions and cu_seqlens presence + if cu_seqlens_q is not None: + # Q has variable length - check tensor dimensions to confirm + if len(q.shape) == 3: # [total_seqlen, nheads, head_dim] + metadata.layout = "thd" + metadata.varlen = True + metadata.cu_seqlens_q = cu_seqlens_q + metadata.max_seqlens_q = max_seqlen_q + + # K might be varlen or batch mode + if cu_seqlens_k is not None: + metadata.cu_seqlens_k = cu_seqlens_k + metadata.max_seqlens_k = max_seqlen_k + else: + # K is in batch mode while Q is varlen (KV cache scenario) + metadata.cu_seqlens_k = None + metadata.max_seqlens_k = k.shape[1] if len(k.shape) == 4 else max_seqlen_k + else: + raise ValueError(f"cu_seqlens_q provided but q has shape {q.shape}, expected 3D tensor for varlen") + else: + # Regular batch mode + metadata.layout = "bshd" + metadata.varlen = False + metadata.cu_seqlens_q = None + metadata.cu_seqlens_k = None + metadata.max_seqlens_q = q.shape[1] if max_seqlen_q is None else max_seqlen_q + metadata.max_seqlens_k = k.shape[1] if max_seqlen_k is None else max_seqlen_k + + # Now determine if we should use decode or prefill kernel + # Decode kernel should be used for KV cache scenarios where: + # 1. k_new/v_new are provided - incremental KV cache update (primary KV cache indicator) + # 2. kv_batch_idx is provided - KV cache batch indexing (primary KV cache indicator) + # 3. seqused_k without seqused_q - indicates KV cache fill levels (not varlen masking) + # Note: In varlen, both seqused_q and seqused_k are used for sequence masking + # In KV cache, only seqused_k is used to track cache fill levels + if USE_DECODE_PATH: + # Force decode path + use_decode = True + else: + # Detect KV cache scenarios: + # - Clear KV cache indicators (k_new, v_new, kv_batch_idx) + # - OR seqused_k without seqused_q (KV cache fill tracking, not varlen masking) + use_decode = ( + k_new is not None or # Have new KV to append (KV cache indicator) + v_new is not None or # Have new KV to append (KV cache indicator) + kv_batch_idx is not None or # Have KV cache batch indexing (KV cache indicator) + (seqused_k is not None and seqused_q is None) # KV cache fill levels (not varlen) + ) + + # Check for unsupported features with decode kernel + if use_decode: + if metadata.layout == "thd": + raise NotImplementedError("Varlen is not yet supported with the decode kernel in the AMD Triton backend") + if kv_batch_idx is not None: + raise NotImplementedError("kv_batch_idx is not yet supported with the decode kernel in the AMD Triton backend") + + + if out is None: + out_dtype = torch.float32 if is_fp8(q) else q.dtype + if metadata.layout == "bshd": + out = torch.zeros(q.shape[0], q.shape[1], q.shape[2], v.shape[-1], dtype=out_dtype, device=q.device) + elif metadata.layout == "thd": + out = torch.zeros(q.shape[0], q.shape[1], v.shape[-1], dtype=out_dtype, device=q.device) + else: + raise ValueError(f"Unsupported layout: {metadata.layout}. Only 'bshd' and 'thd' layouts are supported.") + else: + out = out.zero_() + + if is_fp8(q): + if (q_descale is None) or (k_descale is None) or (v_descale is None): + import warnings + warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) + + # Get shape + if metadata.layout == "bshd": + batch, _, nheads_q, _ = q.shape + else: # "thd" layout for varlen + _, nheads_q, _ = q.shape + batch = len(cu_seqlens_q) - 1 if cu_seqlens_q is not None else 1 + + # Handle causal mask + if causal: + metadata.need_causal(True) + + # Handle alibi slopes (not directly supported in v3 interface, but we'll keep the logic) + alibi_slopes = None # V3 doesn't have alibi_slopes in the signature + if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError(f"Alibi can be (nheads,) or (batch_size, nheads). Given tensor with shape {alibi_slopes.shape}") + metadata.need_alibi(alibi_slopes, batch, nheads_q) + + # Handle dropout (v3 doesn't have dropout in forward) + dropout_p = 0.0 + return_softmax = False + metadata.need_dropout(dropout_p, return_softmax) + + # Handle rotary embeddings + if rotary_cos is not None and rotary_sin is not None: + metadata.need_rotary(rotary_sin, rotary_cos, rotary_interleaved) + + # Apply rotary embeddings if provided + if metadata.causal or window_size_left != -1 or window_size_right != -1: + q_rot = apply_rotary_emb( + q, + rotary_cos, + rotary_sin, + seqlen_offsets=seqlens_rotary, + interleaved=rotary_interleaved, + ) + q = q_rot.to(q.dtype) + + if k_new is not None: + k_rot = apply_rotary_emb( + k_new, + rotary_cos, + rotary_sin, + seqlen_offsets=seqlens_rotary, + interleaved=rotary_interleaved, + ) + k_new = k_rot.to(k.dtype) + + # Store RNG state + rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) + + # Call implementation + if USE_REF: + if DEBUG: + print("Using reference implementation") + + if use_decode: + if DEBUG: + print(f"Using decode reference implementation ( layout={metadata.layout}, cache_seqlens={seqused_k is not None}, k_new={k_new is not None}, v_new={v_new is not None}, kv_batch_idx={kv_batch_idx is not None})") + # Use decode reference implementation + softmax_lse = attention_decode_forward_ref_impl( + q, + k, # k_cache + v, # v_cache + k_new, + v_new, + out, + metadata.sm_scale, + metadata.causal, + window_size_left, + window_size_right, + metadata.alibi_slopes, + metadata.layout, + seqused_k, # cache_seqlens + kv_batch_idx, # cache_batch_idx + page_table, # block_table + q_descale, + k_descale, + v_descale, + ) + else: + if DEBUG: + print("Using prefill reference implementation") + # Use prefill reference implementation + softmax_lse_ref, sd_mask_ref = attention_prefill_forward_ref_impl( + q, k, v, out, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + window_size_left, + window_size_right, + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + USE_EXP2 + ) + softmax_lse = softmax_lse_ref + else: + if DEBUG: + print("Using Triton implementation") + + if use_decode: + if DEBUG: + print(f"Using Decode Triton implementation (cache_seqlens={seqused_k is not None}, k_new={k_new is not None}, v_new={v_new is not None}, kv_batch_idx={kv_batch_idx is not None})") + + # Use decode kernel for KV cache scenarios + # Note: seqused_k can serve as cache_seqlens in v3 + softmax_lse = attention_decode_forward_triton_impl( + q, + k, # k_cache in v2 terminology + v, # v_cache in v2 terminology + k_new, # New KV values to append to cache + v_new, # New KV values to append to cache + out, + metadata.sm_scale, + metadata.causal, + window_size_left, + window_size_right, + metadata.alibi_slopes, + metadata.layout, + seqused_k, # cache_seqlens + kv_batch_idx, # cache_batch_idx + page_table, # block_table for paged attention + q_descale, + k_descale, + v_descale, + ) + # Decode kernel returns only softmax_lse, not sd_mask + sd_mask_triton = None + else: + if DEBUG: + print("Using prefill Triton implementation") + # Use prefill kernel + softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( + q, k, v, out, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + window_size_left, + window_size_right, + None, # block_table + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + metadata.return_softmax, + USE_EXP2, + q_descale, + k_descale, + v_descale, + seqused_q, + seqused_k, + ) + softmax_lse = softmax_lse_triton + + if DEBUG: + print("interface_fa_v3.py::fwd outputs") + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + + # Return format compatible with v3 + # V3 returns (out, softmax_lse, *rest) where rest can be empty or contain additional outputs + return out, softmax_lse + + +def bwd( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + seqused_q: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + deterministic: bool, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Flash Attention v3 backward pass compatible interface for AMD Triton implementation. + + This function maps v3 parameters to the existing AMD Triton implementation. + """ + + if DEBUG: + print() + print("interface_fa_v3.py::bwd inputs") + print("dout:", dout, dout.shape) + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("dq:", dq, dq.shape if dq is not None else None) + print("dk:", dk, dk.shape if dk is not None else None) + print("dv:", dv, dv.shape if dv is not None else None) + print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape if cu_seqlens_q is not None else None) + print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape if cu_seqlens_k is not None else None) + print("seqused_q:", seqused_q, seqused_q.shape if seqused_q is not None else None) + print("seqused_k:", seqused_k, seqused_k.shape if seqused_k is not None else None) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("softcap:", softcap) + print("deterministic:", deterministic) + print("sm_margin:", sm_margin) + + # Check for unsupported features in backward pass + + # Handle softcap + if softcap != 0.0: + raise NotImplementedError(f"Softcap is not yet supported in the AMD Triton backend backward pass (got softcap={softcap}, expected 0.0)") + + # Handle sm_margin + if sm_margin != 0: + raise NotImplementedError(f"sm_margin is not yet supported in the AMD Triton backend backward pass (got sm_margin={sm_margin}, expected 0)") + + # Initialize gradient tensors if not provided + dq = torch.zeros_like(q) if dq is None else dq.zero_() + dk = torch.zeros_like(k) if dk is None else dk.zero_() + dv = torch.zeros_like(v) if dv is None else dv.zero_() + + # Determine layout based on cu_seqlens + if cu_seqlens_q is not None and cu_seqlens_k is not None: + # Variable length sequence mode + layout = "thd" + batch = len(cu_seqlens_q) - 1 + _, nheads_q, _ = q.shape + else: + # Regular batch mode + layout = "bshd" + batch, _, nheads_q, _ = q.shape + max_seqlen_q = q.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_k = k.shape[1] if max_seqlen_k is None else max_seqlen_k + + # V3 backward doesn't have dropout or alibi slopes + dropout_p = 0.0 + philox_seed, philox_offset = None, None + alibi_slopes = None + + # For fp8, we would need descale factors, but v3 interface doesn't expose them + # So we'll pass None for now + descale_q = None + descale_k = None + descale_v = None + descale_o = None + descale_do = None + descale_dq = None + descale_dk = None + descale_dv = None + + # Call implementation + if USE_REF: + if DEBUG: + print("Using reference implementation") + delta_ref = attention_backward_pytorch_ref_impl( + dout, q, k, v, out, softmax_lse, + dq, dk, dv, + softmax_scale, + alibi_slopes, + causal, + window_size_left, + window_size_right, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + USE_EXP2, + ) + delta = delta_ref + else: + if DEBUG: + print("Using Triton implementation") + + if BWD_MODE == "split": + delta_triton = attention_prefill_backward_triton_split_impl( + dout, q, k, v, out, softmax_lse, + dq, dk, dv, + softmax_scale, + alibi_slopes, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + USE_EXP2, + descale_q, descale_k, descale_v, descale_o, + descale_do, descale_dq, descale_dk, descale_dv, + seqused_q, seqused_k, + ) + delta = delta_triton + elif BWD_MODE == "fused_atomics": + delta_triton = attention_prefill_backward_triton_fused_atomics_impl( + dout, q, k, v, out, softmax_lse, + dq, dk, dv, + softmax_scale, + alibi_slopes, + causal, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + descale_q, descale_k, descale_v, descale_o, + True, + ) + delta = delta_triton + elif BWD_MODE == "fused_no_atomics": + delta_triton = attention_prefill_backward_triton_split_fused_no_atomics_impl( + dout, q, k, v, out, softmax_lse, + dq, dk, dv, + softmax_scale, + alibi_slopes, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + USE_EXP2, + descale_q, descale_k, descale_v, descale_o, + descale_do, descale_dq, descale_dk, descale_dv, + seqused_q, seqused_k, + ) + delta = delta_triton + else: + raise ValueError(f"Unknown bwd mode {BWD_MODE}") + + if DEBUG: + print("interface_fa_v3.py::bwd outputs") + print("dq:", dq, dq.shape) + print("dk:", dk, dk.shape) + print("dv:", dv, dv.shape) + print("delta:", delta, delta.shape if delta is not None else None) + + # V3 expects (dq, dk, dv, softmax_d, *rest) + # delta is the softmax_d in this case + return dq, dk, dv, delta + + +def fwd_combine( + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """ + Combine partial outputs from split attention computation. + + This is used when num_splits > 1 to combine the partial results. + + Args: + out_partial: Partial output tensor from split computation + lse_partial: Partial log-sum-exp tensor + out: Optional output tensor to write to + out_dtype: Optional dtype for output + + Returns: + Combined output tensor + """ + raise NotImplementedError("fwd_combine is not yet implemented in the AMD Triton backend") + + +def get_scheduler_metadata( + batch_size: int, + max_seqlen_q: int, + max_seqlen_k: int, + num_heads_q: int, + num_heads_kv: int, + headdim: int, + headdim_v: int, + qkv_dtype: torch.dtype, + cache_seqlens: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new: int = 0, + causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + attention_chunk: int = 0, + has_softcap: bool = False, + num_splits: int = 0, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, +): + """ + Get scheduler metadata for optimized kernel selection. + + This function is used to precompute metadata for kernel scheduling in FA3. + The AMD Triton backend currently doesn't use scheduler metadata, so this + raises an error. + + Args: + Various attention parameters used for scheduling decisions + + Returns: + None - scheduler metadata is not used in AMD Triton backend + """ + raise NotImplementedError("get_scheduler_metadata is not supported in the AMD Triton backend yet.") \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py deleted file mode 100644 index 15fa8f2f07a..00000000000 --- a/flash_attn/flash_attn_triton_amd/test.py +++ /dev/null @@ -1,777 +0,0 @@ -import os -import glob -import shutil -import time -import torch -import pytest -import logging -import numpy as np -from pathlib import Path -import triton -import flash_attn -import flash_attn.flash_attn_triton_amd.fp8 - -from .utils import generate_bshd_kv_packed, generate_bshd_qkv_packed, generate_bshd_tensor, generate_varlen_kv_packed, generate_varlen_qkv_packed, input_helper, arch_supports_fp8, generate_varlen_tensor - -# debugging - -logging.basicConfig(level=logging.INFO, format='%(message)s', force=True) -logger = logging.getLogger(__name__) -DEBUG = False - -# defailt fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. See table https://pytorch.org/docs/stable/testing.html -ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe to lose. -# ATOL, RTOL = 1e-3, 1e-3 # catchs fa mismatch issues -# ATOL, RTOL = 1e-4, 1e-3 # to strict. there will be small diffs -# ATOL, RTOL = 1e-5, 1e-3 # # default fp16. there will be small diffs -# ATOL_fp8, RTOL_fp8 = 1e-1, 1e-1 # to strict for larger tensors in fp8 -ATOL_fp8, RTOL_fp8 = 2.5e-1, 2.5e-1 # fp8 -# ATOL_fp8, RTOL_fp8 = 2e-2, 2e-2 # fp8 -EQUAL_NAN = True - -def fp8_assert_close(tensor_a, tensor_b, atol=ATOL_fp8, rtol=RTOL_fp8, max_diff_percentage=0.5): - """Assert tensors are close with tolerance for small percentage of elements""" - # standard comparison - abs_diff = torch.abs(tensor_a - tensor_b) - rel_diff = abs_diff / torch.abs(tensor_b.clamp(min=1e-6)) - - # calculate elements that exceed tolerance - abs_check = abs_diff > atol - rel_check = rel_diff > rtol - failed_check = torch.logical_and(abs_check, rel_check) - - # calculate percentage of failed elements - failed_percentage = failed_check.sum().item() / failed_check.numel() * 100 - - # if percentage is small enough, test passes - if failed_percentage <= max_diff_percentage: - return True - - # Otherwise, provide diagnostic information - max_abs_idx = torch.argmax(abs_diff).item() - max_rel_idx = torch.argmax(rel_diff).item() - - flat_to_idx = lambda flat_idx, shape: np.unravel_index(flat_idx, shape) - - max_abs_pos = flat_to_idx(max_abs_idx, tensor_a.shape) - max_rel_pos = flat_to_idx(max_rel_idx, tensor_a.shape) - - max_abs_diff = abs_diff.flatten()[max_abs_idx].item() - max_rel_diff = rel_diff.flatten()[max_rel_idx].item() - - raise AssertionError( - f"Tensors not close enough! {failed_percentage:.6f}% elements exceed tolerance.\n" - f"Greatest absolute difference: {max_abs_diff} at index {max_abs_pos} (up to {atol} allowed)\n" - f"Greatest relative difference: {max_rel_diff} at index {max_rel_pos} (up to {rtol} allowed)" - ) - -@pytest.mark.parametrize( - "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", - [ - # seqlen q == k - (1, 1, 1, 1, 1, 1), - (1, 1, 1, 2, 2, 2), # small enough to debug - (1, 1, 1, 4, 4, 16), - (1, 2, 2, 4, 4, 16), - (2, 1, 1, 4, 4, 16), - (2, 2, 2, 4, 4, 16), - (1, 1, 1, 128, 128, 32), # only one block - (3, 3, 3, 128, 128, 64), - (1, 1, 1, 127, 127, 32), # only one block but with masking - # (1, 1, 1, 129, 129, 1), # two blocks with 2nd block small enough to debug # fails - (1, 2, 2, 129, 129, 32), # two blocks with 2nd block small enough to debug - (1, 1, 1, 350, 350, 32), # two blocks with 2nd block small enough to debug - (1, 1, 1, 350, 350, 68), # generic masking on q, k and head - (4, 1, 1, 512, 512, 128), # batch > 1 - (4, 2, 2, 512, 512, 128), - (4, 2, 2, 512, 512, 68), - (4, 2, 2, 500, 500, 68), - (2, 4, 4, 1024, 1024, 64), - (4, 8, 8, 2048, 2048, 128), - (2, 8, 8, 4096, 4096, 64), - (2, 4, 4, 8192, 8192, 32), - # seqlen q > k - (1, 1, 1, 4, 2, 16), - (1, 1, 1, 64, 32, 8), - (1, 1, 1, 128, 64, 16), - (1, 1, 1, 192, 128, 32), - (1, 2, 2, 1024, 512, 68), - (1, 4, 4, 729, 516, 68), - (2, 4, 4, 2753, 1528, 68), # a comprehensive seqlen_q > seqlen_k - # seqlen q < k - (1, 1, 1, 2, 4, 16), - (1, 2, 2, 2, 4, 16), - (1, 4, 1, 2, 4, 16), - (1, 4, 2, 2, 4, 16), - (2, 2, 2, 2, 128, 1), - (2, 3, 3, 2, 128, 16), - (1, 1, 1, 32, 64, 8), - (1, 1, 1, 128, 192, 32), - (4, 6, 6, 108, 256, 32), - (3, 2, 2, 256, 512, 16), - (2, 2, 2, 512, 1024, 68), - (1, 1, 1, 200, 413, 32), - (1, 1, 1, 782, 1546, 32), - # gqa/mqa # mismatch issue on varlen - (4, 8, 2, 500, 500, 68), - (4, 8, 2, 512, 512, 68), - (4, 8, 2, 512, 512, 128), - (4, 8, 2, 512, 1024, 68), - (4, 8, 2, 1024, 512, 64), - (4, 16, 4, 1528, 2753, 68), - # fa configs - (2, 4, 1, 113, 203, 64), - (2, 4, 2, 128, 217, 64), - (2, 6, 2, 113, 211, 128), - (2, 6, 2, 108, 256, 128), - (2, 6, 2, 256, 512, 64), - (2, 6, 2, 512, 256, 64), - (2, 6, 2, 1024, 1024, 32), - (2, 6, 2, 1023, 1024, 32), - (2, 6, 6, 1024, 1023, 32), - (2, 6, 6, 2048, 2048, 32), - ], -) -@pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('dropout_p', [0.0]) -@pytest.mark.parametrize('layout', ["bshd", "thd"]) -@pytest.mark.parametrize('packing', [None, "qkv"]) -@pytest.mark.parametrize('DEBUG_INPUT', [False]) -@pytest.mark.flaky(reruns=3, reason="Retry failures") -@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") -def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, packing, DEBUG_INPUT): - torch.manual_seed(20) - test_backward = True - device = "cuda" - window_size = (-1, -1) - softcap = 0.0 - alibi_slopes = None - deterministic = False - ref_dtype = torch.float32 - is_varlen = True if layout == "thd" else False - - # skip QKV packing tests for uneven sequence lengths and head sizes - if packing == 'qkv': - if N_CTX_Q != N_CTX_K: - pytest.skip("QKV packing requires N_CTX_Q == N_CTX_K") - if HQ != HK: - pytest.skip("QKV packing requires HQ == HK") - - # test apis - if packing == 'qkv': - # generate inputs - qkv, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, packing=packing, device=device) - - # ---------------------------------------------------------------- - # --- FP8 --- - # ---------------------------------------------------------------- - qkv_fp8 = qkv.clone() - do_fp8= do.clone() - - if is_varlen: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_qkvpacked_fp8_func( - qkv_fp8, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_qkvpacked_fp8_func( - qkv_fp8, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- - # --- Reference --- - # ---------------------------------------------------------------- - # reference forward pass - qkv_ref = qkv.clone() - do_ref= do.clone() - - if is_varlen: - out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_varlen_qkvpacked_func( - qkv_ref, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_qkvpacked_func( - qkv_ref, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- - # --- Compare --- - # ---------------------------------------------------------------- - # compare forward - if DEBUG: - print() - print(f"Compare fp8 against ref with dtype {ref_dtype}") - - if DEBUG: - print("out_ref:", out_ref, out_ref.shape) - print("out_fp8:", out_fp8, out_fp8.shape) - fp8_assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - - if DEBUG: - print("lse_ref:", lse_ref, lse_ref.shape) - print("lse_fp8:", lse_fp8, lse_fp8.shape) - fp8_assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - - if dropout_p > 0.0: - if DEBUG: - print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape) - print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape) - fp8_assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - if not test_backward: - return - - # fp8 backward pass - dqkv_fp8, = torch.autograd.grad(out_fp8, (qkv_fp8), do_fp8) - - # ref backward pass - dqkv_ref, = torch.autograd.grad(out_ref, (qkv_ref), do_ref) - - # compare backward gradients - if DEBUG: - print("dqkv_ref:", dqkv_ref, dqkv_ref.shape) - print("dqkv_fp8:", dqkv_fp8, dqkv_fp8.shape) - fp8_assert_close(dqkv_ref, dqkv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - elif packing is None: - # generate inputs - q, k, v, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, device=device) - - # ---------------------------------------------------------------- - # --- FP8 --- - # ---------------------------------------------------------------- - if DEBUG: - print() - print(f"Compute Fp8 Forward") - q_fp8 = q.clone() - k_fp8 = k.clone() - v_fp8 = v.clone() - do_fp8= do.clone() - - if is_varlen: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_fp8_func( - q_fp8, - k_fp8, - v_fp8, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_fp8_func( - q_fp8, - k_fp8, - v_fp8, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- - # --- Reference --- - # ---------------------------------------------------------------- - if DEBUG: - print() - print(f"Compute Reference Forward") - # reference forward pass - q_ref = q.clone() - k_ref = k.clone() - v_ref = v.clone() - do_ref = do.clone() - - if is_varlen: - out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_varlen_func( - q_ref, - k_ref, - v_ref, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_func( - q_ref, - k_ref, - v_ref, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- - # --- Compare --- - # ---------------------------------------------------------------- - # compare forward - if DEBUG: - print() - print(f"Compare fp8 against ref with dtype {ref_dtype}") - - if DEBUG: - print("out_ref:", out_ref, out_ref.shape) - print("out_fp8:", out_fp8, out_fp8.shape) - # torch.testing.assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - fp8_assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - - if DEBUG: - print("lse_ref:", lse_ref, lse_ref.shape) - print("lse_fp8:", lse_fp8, lse_fp8.shape) - # torch.testing.assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - fp8_assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - - if dropout_p > 0.0: - if DEBUG: - print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape) - print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape) - # torch.testing.assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - fp8_assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - if not test_backward: - return - - if DEBUG: - print() - print(f"Compute Fp8 Backward") - # fp8 backward pass - dq_fp8, dk_fp8, dv_fp8 = torch.autograd.grad(out_fp8, (q_fp8, k_fp8, v_fp8), do_fp8) - - if DEBUG: - print() - print(f"Compute Reference Backward") - # ref backward pass - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), do_ref) - - # compare backward gradients - if DEBUG: - print("dv_ref:", dv_ref, dv_ref.shape) - print("dv_fp8:", dv_fp8, dv_fp8.shape) - # torch.testing.assert_close(dv_ref, dv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) - fp8_assert_close(dv_ref, dv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - if DEBUG: - print("dk_ref:", dk_ref, dk_ref.shape) - print("dk_fp8:", dk_fp8, dk_fp8.shape) - # torch.testing.assert_close(dk_ref, dk_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) - fp8_assert_close(dk_ref, dk_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - if DEBUG: - print("dq_ref:", dq_ref, dq_ref.shape) - print("dq_fp8:", dq_fp8, dq_fp8.shape) - # torch.testing.assert_close(dq_ref, dq_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) - fp8_assert_close(dq_ref, dq_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - -@pytest.mark.parametrize( - "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", - [ - (2, 4, 4, 512, 512, 128), - ], -) -@pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('dropout_p', [0.0, 0.1]) -@pytest.mark.parametrize('layout', ['bshd']) -@pytest.mark.parametrize('packing', [None]) -@pytest.mark.parametrize('test_backward', [False, True]) -@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") -@pytest.mark.skip("Breaks on CI but works locally") -def test_ir(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, packing, test_backward): # Don't run this test in parallel. It clears the cache so it doesnot work properly if run in parallel. - torch.manual_seed(20) - device = "cuda" - window_size = (-1, -1) - softcap = 0.0 - alibi_slopes = None - deterministic = False - ref_dtype = torch.float32 - is_varlen = True if layout == "thd" else False - - # remove cache - cache_path = Path(os.path.expanduser("~/.triton/cache")) - if cache_path.exists(): - shutil.rmtree(cache_path) - os.makedirs(cache_path) - - # inputs - q, k, v, do, metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout=layout, packing=packing, device=device) - - if packing == None: - # fp8 forward pass - if is_varlen: - out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_fp8_func( - q, - k, - v, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_fp8_func( - q, - k, - v, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # fp8 backward pass - if test_backward: - dq, dk, dv = torch.autograd.grad(out, (q, k, v), do) - elif packing == "qkv": - # qkv packing path - # pack input tensors (use dim=1 for varlen, else dim=2) - if is_varlen: - qkv = torch.stack([q, k, v], dim=1) - else: - qkv = torch.stack([q, k, v], dim=2) - - # fp8 forward pass for qkv-packed input - if is_varlen: - out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_qkvpacked_fp8_func( - qkv, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_qkvpacked_fp8_func( - qkv, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # fp8 backward pass for qkv-packed input - if test_backward: - dqkv, = torch.autograd.grad(out, (qkv,), do) - else: - raise ValueError(f"unknown packing type {packing}") - - # search for .ttir files - max_retries = 5 - retry_delay = 0.5 - ttir_files = [] - logging.info(f"Checking for .ttir files in {cache_path}...") - for attempt in range(max_retries): - # search for .ttir files recursively within the cache path - ttir_files = glob.glob(str(cache_path) + "/**/*.ttir", recursive=True) - - if ttir_files: - # Files found, log success and exit the loop - logging.info(f"Found {len(ttir_files)} .ttir files on attempt {attempt + 1}.") - break - else: - # Files not found yet - if attempt < max_retries - 1: - # If not the last attempt, wait and log before retrying - logging.warning( - f"No .ttir files found on attempt {attempt + 1}. " - f"Retrying in {retry_delay}s..." - ) - time.sleep(retry_delay) - else: - pytest.fail( - f"FATAL: No .ttir files found in cache {cache_path} " - f"after {max_retries} attempts." - ) - - # check if there is fp8 - ttir_files_fp8_found_status = {} - fp8_types = ['f8E4M3', 'f8E5M2'] - for ttir_file in ttir_files: - base_name = os.path.basename(ttir_file) - with open(ttir_file, 'r') as f: - content = f.read() - - # check content for fp8 - fp8_found = False - for f8_type in fp8_types: - if f8_type in content: - fp8_found = True - ttir_files_fp8_found_status[base_name] = fp8_found - - for file, fp8_found in ttir_files_fp8_found_status.items(): - assert fp8_found, f"{fp8_types} not found in {file}" - - -def clear_compile_cache(): - """Clear torch compile caches to prevent graph merging""" - if hasattr(torch._dynamo, 'reset'): - torch._dynamo.reset() - torch.cuda.synchronize() - - -@pytest.mark.parametrize( - "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", - [ - # (4, 8, 8, 128, 128, 64), # small test - (32, 32, 32, 531, 531, 128), # original test - # (16, 48, 16, 256, 512, 64), # MQA test (HQ > HK) - ], -) -@pytest.mark.skip() # this works or not depending on the torch and triton version compatibility. Keep the test but use it in docker containers where things are in sync -def test_torch_compile(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD): - print(f"\n\nTesting with BATCH={BATCH}, HQ={HQ}, HK={HK}, N_CTX_Q={N_CTX_Q}, N_CTX_K={N_CTX_K}, D_HEAD={D_HEAD}") - - try: - # Test 1: flash_attn_func - print("\n1. Testing flash_attn_func...") - clear_compile_cache() - - q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD) - k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD) - v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD) - - flash_attn_func_compiled = torch.compile(flash_attn.flash_attn_func) - o = flash_attn_func_compiled(q, k, v, causal=True) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_func SUCCESS") - - # cleanup - del q, k, v, o - torch.cuda.empty_cache() - - - # Test 2: flash_attn_varlen_func - print("\n2. Testing flash_attn_varlen_func...") - clear_compile_cache() - - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(BATCH * N_CTX_Q, HQ, D_HEAD, BATCH) - k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(BATCH * N_CTX_K, HK, D_HEAD, BATCH) - v, _, _ = generate_varlen_tensor(BATCH * N_CTX_K, HK, D_HEAD, BATCH) - - flash_attn_varlen_func_compiled = torch.compile(flash_attn.flash_attn_varlen_func) - o = flash_attn_varlen_func_compiled( - q, k, v, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, causal=True - ) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_varlen_func SUCCESS") - - # cleanup - del q, k, v, o, cu_seqlens_q, cu_seqlens_k - torch.cuda.empty_cache() - - - # Test 3: flash_attn_qkvpacked_func - print("\n3. Testing flash_attn_qkvpacked_func...") - clear_compile_cache() - - qkv = generate_bshd_qkv_packed(BATCH, N_CTX_Q, HQ, D_HEAD) - - flash_attn_qkvpacked_func_compiled = torch.compile(flash_attn.flash_attn_qkvpacked_func) - o = flash_attn_qkvpacked_func_compiled(qkv, causal=True) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_qkvpacked_func SUCCESS") - - # cleanup - del qkv, o - torch.cuda.empty_cache() - - - # Test 4: flash_attn_varlen_qkvpacked_func - print("\n4. Testing flash_attn_varlen_qkvpacked_func...") - clear_compile_cache() - - total_q = BATCH * N_CTX_Q - qkv, cu_seqlens, max_seqlen = generate_varlen_qkv_packed(total_q, HQ, D_HEAD, BATCH) - - flash_attn_varlen_qkvpacked_func_compiled = torch.compile(flash_attn.flash_attn_varlen_qkvpacked_func) - o = flash_attn_varlen_qkvpacked_func_compiled( - qkv, cu_seqlens, max_seqlen, causal=True - ) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_varlen_qkvpacked_func SUCCESS") - - # cleanup - del qkv, o, cu_seqlens - torch.cuda.empty_cache() - - - # Test 5: flash_attn_kvpacked_func - print("\n5. Testing flash_attn_kvpacked_func...") - clear_compile_cache() - - q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD) - kv = generate_bshd_kv_packed(BATCH, N_CTX_K, HK, D_HEAD) - - flash_attn_kvpacked_func_compiled = torch.compile(flash_attn.flash_attn_kvpacked_func) - o = flash_attn_kvpacked_func_compiled(q, kv, causal=True) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_kvpacked_func SUCCESS") - - # cleanup - del q, kv, o - torch.cuda.empty_cache() - - - # Test 6: flash_attn_varlen_kvpacked_func - print("\n6. Testing flash_attn_varlen_kvpacked_func...") - clear_compile_cache() - - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(BATCH * N_CTX_Q, HQ, D_HEAD, BATCH) - kv, cu_seqlens_k, max_seqlen_k = generate_varlen_kv_packed(BATCH * N_CTX_K, HK, D_HEAD, BATCH) - - flash_attn_varlen_kvpacked_func_compiled = torch.compile(flash_attn.flash_attn_varlen_kvpacked_func) - o = flash_attn_varlen_kvpacked_func_compiled( - q, kv, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, causal=True - ) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_varlen_kvpacked_func SUCCESS") - - # cleanup - del q, kv, o, cu_seqlens_q, cu_seqlens_k - torch.cuda.empty_cache() - - - # Test 7: flash_attn_with_kvcache - print("\n7. Testing flash_attn_with_kvcache...") - clear_compile_cache() - - # setup cache dimensions - CACHE_SEQLEN = 1024 # max cache size - NEW_SEQLEN = 1 # for incremental decoding, usually 1 token at a time - - # create query for new tokens - q = generate_bshd_tensor(BATCH, NEW_SEQLEN, HQ, D_HEAD, dtype=torch.float16) - - # create kv cache using generators - k_cache = generate_bshd_tensor(BATCH, CACHE_SEQLEN, HK, D_HEAD, dtype=torch.float16) - v_cache = generate_bshd_tensor(BATCH, CACHE_SEQLEN, HK, D_HEAD, dtype=torch.float16) - - # cache sequence lengths - cache_seqlens = torch.full((BATCH,), 100, dtype=torch.int32, device='cuda') - - # new k,v to append to cache (optional) - k_new = generate_bshd_tensor(BATCH, NEW_SEQLEN, HK, D_HEAD, dtype=torch.float16) - v_new = generate_bshd_tensor(BATCH, NEW_SEQLEN, HK, D_HEAD, dtype=torch.float16) - - # Note: flash_attn_with_kvcache doesn't support backward pass - flash_attn_with_kvcache_compiled = torch.compile(flash_attn.flash_attn_with_kvcache) - - # Test with new k,v (append to cache and do attention) - with torch.no_grad(): - o = flash_attn_with_kvcache_compiled( - q, k_cache, v_cache, - k=k_new, v=v_new, - cache_seqlens=cache_seqlens, - causal=True - ) - print(f"Output shape (with new kv): {o.shape}, dtype: {o.dtype}") - - print("✓ flash_attn_with_kvcache SUCCESS") - - print("\n\n✅ ALL TESTS PASSED! ✅") - - except Exception as e: - print(f"\n❌ ERROR: {str(e)}") - # ensure we sync even on error to get proper error message - torch.cuda.synchronize() - raise e - finally: - # final cleanup - torch.cuda.empty_cache() - clear_compile_cache() - -# log env -if os.environ.get('PYTEST_XDIST_WORKER') in (None, 'gw0'): - logger.info("\n" + "="*80) - logger.info("ENVIRONMENT INFORMATION") - logger.info("="*80) - - # triton - logger.info(f" Triton Version: {triton.__version__}") - - # torch - logger.info(f" PyTorch Version: {torch.__version__}") - - # flash attention - logger.info(f" Flash Attention Version: {flash_attn.__version__}") - - logger.info("="*80 + "\n") diff --git a/flash_attn/flash_attn_triton_amd/train.py b/flash_attn/flash_attn_triton_amd/train.py deleted file mode 100644 index 8f39f627bfb..00000000000 --- a/flash_attn/flash_attn_triton_amd/train.py +++ /dev/null @@ -1,404 +0,0 @@ -import torch -import torch.nn as nn -import torch.optim as optim -from torch.utils.data import DataLoader, Dataset, random_split -import numpy as np -import pandas as pd -from tqdm import tqdm -import matplotlib.pyplot as plt -from datasets import load_dataset -import flash_attn -import flash_attn.flash_attn_triton_amd.fp8 - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -print(f"using device: {device}") - -# ------------------------------- -# Model -# ------------------------------- -class FlashAttention(nn.Module): - def __init__(self, dim, num_heads=8, causal=True, dropout=0.1, qkv_bias=True, use_fp8=False): - super().__init__() - self.use_fp8 = use_fp8 - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - self.causal = causal - self.dropout_p = dropout - - # qkv and output projections - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) - - def forward(self, x): - b, n, c = x.shape - # project to qkv - qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - - # reshape for flash attention function - qkv_packed = torch.stack([q, k, v], dim=2).reshape(b, n, 3, self.num_heads, self.head_dim) - - # use the appropriate flash attention function - if self.use_fp8: - context = flash_attn.flash_attn_triton_amd.fp8.flash_attn_qkvpacked_fp8_func( - qkv_packed, - dropout_p=self.dropout_p, - causal=self.causal - ) - else: - context = flash_attn.flash_attn_qkvpacked_func( - qkv_packed, - dropout_p=self.dropout_p, - causal=self.causal - ) - - # convert back to original shape and type - context = context.reshape(b, n, c) - - # output projection - x = self.proj(context) - - return x - -class TransformerBlock(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4.0, causal=True, dropout=0.1, use_fp8=False): - super().__init__() - self.norm1 = nn.LayerNorm(dim) - self.attn = FlashAttention(dim, num_heads=num_heads, causal=causal, dropout=dropout, use_fp8=use_fp8) - - self.norm2 = nn.LayerNorm(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = nn.Sequential( - nn.Linear(dim, mlp_hidden_dim), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(mlp_hidden_dim, dim), - nn.Dropout(dropout) - ) - - def forward(self, x): - x = x + self.attn(self.norm1(x)) - x = x + self.mlp(self.norm2(x)) - return x - -class FlashLM(nn.Module): - def __init__( - self, - vocab_size, - dim=256, - depth=6, - num_heads=8, - mlp_ratio=4.0, - causal=True, - dropout=0.1, - max_seq_len=256, - use_fp8=False - ): - super().__init__() - - # embedding layers - self.token_embedding = nn.Embedding(vocab_size, dim) - self.position_embedding = nn.Parameter(torch.zeros(1, max_seq_len, dim)) - self.dropout = nn.Dropout(dropout) - - # transformer blocks - self.blocks = nn.ModuleList([ - TransformerBlock(dim, num_heads, mlp_ratio, causal=causal, dropout=dropout, use_fp8=use_fp8) - for _ in range(depth) - ]) - - # lm head: project back to vocabulary dimension for each token - self.norm = nn.LayerNorm(dim) - self.lm_head = nn.Linear(dim, vocab_size) - - def forward(self, x): - b, n = x.shape - - # token + positional embedding - x = self.token_embedding(x) - x = x + self.position_embedding[:, :n, :] - x = self.dropout(x) - - # transformer blocks - for block in self.blocks: - x = block(x) - - # language modeling head - x = self.norm(x) - logits = self.lm_head(x) # shape: (b, n, vocab_size) - return logits - -# ------------------------------- -# Data -# ------------------------------- -class TextDataset(Dataset): - def __init__(self, sequences, max_len=None): - self.sequences = sequences - self.max_len = max_len - - def __len__(self): - return len(self.sequences) - - def __getitem__(self, idx): - seq = self.sequences[idx] - # input: all tokens except the last, target: all tokens except the first - return (torch.tensor(seq[:-1], dtype=torch.long), - torch.tensor(seq[1:], dtype=torch.long)) - -class VarLenTextDataset(Dataset): - def __init__(self, sequences, max_len=256): - self.sequences = sequences - self.max_len = max_len - - def __len__(self): - return len(self.sequences) - - def __getitem__(self, idx): - seq = self.sequences[idx] - # Ensure the sequence doesn't exceed max_len+1 - seq = seq[:self.max_len+1] - # input: all tokens except the last, target: all tokens except the first - return (torch.tensor(seq[:-1], dtype=torch.long), - torch.tensor(seq[1:], dtype=torch.long)) - -def prepare_dataset(batch_size, is_varlen=False, min_len=10, max_len=256, ratio_shorter=0.7): - # load the WikiText-2 - dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") - - # build vocabulary - corpus = " ".join([line for line in dataset["text"] if line.strip() != ""]) # join non-empty lines into a single corpus string - tokens = corpus.split() - vocab = sorted(set(tokens)) - word2idx = {word: idx for idx, word in enumerate(vocab)} - token_ids = [word2idx[word] for word in tokens] - - num_workers = 2 - if is_varlen: - # VARIABLE LENGTH: create sequences of different lengths - sequences = [] - for i in range(0, len(token_ids) - max_len, max_len // 2): # overlap to get more sequences - # Decide target length for this sequence - if np.random.random() < ratio_shorter: - # Shorter sequence - target_len = np.random.randint(min_len + 1, max_len + 1) - else: - # Full length sequence - target_len = max_len + 1 - - # Extract sequence up to target length or whatever's available - seq_end = min(i + target_len, len(token_ids)) - seq = token_ids[i:seq_end] - - # Only keep sequences that are long enough - if len(seq) > min_len + 1: # +1 because we need both input and target - sequences.append(seq) - - print(f"Created {len(sequences)} variable-length sequences") - - # Get some statistics - lens = [len(seq) for seq in sequences] - print(f"Sequence length stats: min={min(lens)}, max={max(lens)}, mean={np.mean(lens):.1f}") - - # split dataset - num_samples = len(sequences) - num_train = int(0.8 * num_samples) - num_val = num_samples - num_train - - # Use appropriate dataset class based on whether we need variable length - dataset_class = VarLenTextDataset - train_sequences = sequences[:num_train] - val_sequences = sequences[num_train:] - - train_dataset = dataset_class(train_sequences, max_len) - val_dataset = dataset_class(val_sequences, max_len) - - - # collate function - def collate_fn(batch): - """ - Collate function that creates a flat representation for variable length flash attention. - """ - # Separate inputs and targets - inputs, targets = zip(*batch) - - # Get sequence lengths - seq_lens = torch.tensor([len(x) for x in inputs], dtype=torch.int32) - - # Concatenate inputs and targets into single tensors - flat_inputs = torch.cat(inputs) - flat_targets = torch.cat(targets) - - # Create cumulative sequence lengths tensor - cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32) - cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0) - - # Calculate max sequence length for this batch - max_seqlen = seq_lens.max().item() - - return flat_inputs, flat_targets, seq_lens, cu_seqlens, max_seqlen - - # data loaders - train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn) - val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn) - else: - # FIXED LENGTH: create sequences of length max_len+1 - sequences = [] - for i in range(0, len(token_ids) - max_len, max_len): - seq = token_ids[i : i + max_len + 1] - if len(seq) == max_len + 1: - sequences.append(seq) - - # split dataset - num_samples = len(sequences) - num_train = int(0.8 * num_samples) - num_val = num_samples - num_train - train_dataset, val_dataset = random_split(TextDataset(sequences), [num_train, num_val]) - - # data loaders - train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) - val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) - - vocab_size = len(vocab) - print(f"vocab size: {vocab_size}, train samples: {len(train_dataset)}, validation samples: {len(val_dataset)}") - return train_dataloader, val_dataloader, vocab_size - -# ------------------------------- -# Training -# ------------------------------- -def train_lm(model, train_dataloader, val_dataloader, optimizer, criterion, num_epochs): - train_losses = [] - val_losses = [] - for epoch in range(num_epochs): - # Training phase - model.train() - epoch_train_loss = 0.0 - for inputs, targets in tqdm(train_dataloader, desc=f"epoch {epoch+1}/{num_epochs} [train]"): - inputs, targets = inputs.to(device), targets.to(device) - - optimizer.zero_grad() - logits = model(inputs) - loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1)) - loss.backward() - optimizer.step() - - epoch_train_loss += loss.item() - - epoch_train_loss /= len(train_dataloader) - train_losses.append(epoch_train_loss) - print(f"epoch {epoch+1}/{num_epochs} - train loss: {epoch_train_loss:.4f}") - - # Validation phase - model.eval() - epoch_val_loss = 0.0 - with torch.no_grad(): - for inputs, targets in tqdm(val_dataloader, desc=f"epoch {epoch+1}/{num_epochs} [validation]"): - inputs, targets = inputs.to(device), targets.to(device) - logits = model(inputs) - loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1)) - epoch_val_loss += loss.item() - epoch_val_loss /= len(val_dataloader) - val_losses.append(epoch_val_loss) - print(f"epoch {epoch+1}/{num_epochs} - validation loss: {epoch_val_loss:.4f}") - - return train_losses, val_losses - -# ------------------------------- -# Main -# ------------------------------- -def main(): - # hyperparameters - batch_size = 16 - num_epochs = 20 - learning_rate = 3e-4 - max_len = 128 # total length including both input and target tokens - is_varlen = False - causal=True - dropout=0.1 - - # prep data - print("Preparing Dataset") - train_dataloader, val_dataloader, vocab_size = prepare_dataset(batch_size, max_len=max_len, is_varlen=is_varlen) - - # create language models - print("Creating Models") - model_normal = FlashLM( - vocab_size=vocab_size, - dim=256, - depth=3, - num_heads=8, - causal=causal, - dropout=dropout, - max_seq_len=max_len, - ).to(device) - - model_fp8 = FlashLM( - vocab_size=vocab_size, - dim=256, - depth=3, - num_heads=8, - causal=causal, - dropout=dropout, - max_seq_len=max_len, - use_fp8=True - ).to(device) - - # Train Normal model - print("Starting training for Normal model...") - optimizer_normal = optim.AdamW(model_normal.parameters(), lr=learning_rate) - criterion = nn.CrossEntropyLoss() - normal_train_losses, normal_val_losses = train_lm( - model_normal, train_dataloader, val_dataloader, optimizer_normal, criterion, num_epochs - ) - torch.save(model_normal.state_dict(), 'flash_lm_normal.pth') - print("Normal model training complete and saved.") - - # Train FP8 model - print("Starting training for FP8 model...") - optimizer_fp8 = optim.AdamW(model_fp8.parameters(), lr=learning_rate) - fp8_train_losses, fp8_val_losses = train_lm( - model_fp8, train_dataloader, val_dataloader, optimizer_fp8, criterion, num_epochs - ) - torch.save(model_fp8.state_dict(), 'flash_lm_fp8.pth') - print("FP8 model training complete and saved.") - - # save losses to csv - epochs = range(1, num_epochs+1) - loss_data = { - "Epoch": epochs, - "Normal_Training_Loss": normal_train_losses, - "Normal_Validation_Loss": normal_val_losses, - "FP8_Training_Loss": fp8_train_losses, - "FP8_Validation_Loss": fp8_val_losses, - } - df_losses = pd.DataFrame(loss_data) - df_losses.to_csv("losses.csv", index=False) - print("Loss data saved to losses.csv") - - # plot Training Loss - plt.figure(figsize=(10, 6)) - plt.plot(epochs, normal_train_losses, label="Normal Training Loss", marker='o') - plt.plot(epochs, fp8_train_losses, label="FP8 Training Loss", marker='x') - plt.xlabel("Epoch") - plt.ylabel("Training Loss") - plt.title("Training Loss Comparison: Normal vs FP8 Flash Attention") - plt.legend() - plt.grid(True) - plt.savefig("training_loss.png") # Saves the training loss plot to disk - plt.show() - - # Plot Validation Loss - plt.figure(figsize=(10, 6)) - plt.plot(epochs, normal_val_losses, label="Normal Validation Loss", marker='o') - plt.plot(epochs, fp8_val_losses, label="FP8 Validation Loss", marker='x') - plt.xlabel("Epoch") - plt.ylabel("Validation Loss") - plt.title("Validation Loss Comparison: Normal vs FP8 Flash Attention") - plt.legend() - plt.grid(True) - plt.savefig("validation_loss.png") # Saves the validation loss plot to disk - plt.show() - - -if __name__ == "__main__": - main() diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 96d4f662567..44502785a35 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -139,11 +139,11 @@ def check_args(self, q, k, v, o): assert q.dim() == 4 assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 assert self.cu_seqlens_q is None and self.cu_seqlens_k is None - assert k.shape == v.shape - assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] # and q.shape[-1] == v.shape[-1] # TODO: Change assert if we support qkl f8 and v f16 assert q.dtype == k.dtype and q.dtype == v.dtype - assert o.shape == q.shape + assert o.shape[:-1] == q.shape[:-1] and o.shape[-1] == v.shape[-1] assert (nheads_q % nheads_k) == 0 assert self.layout is not None assert self.layout == 'thd' or not self.varlen @@ -1064,91 +1064,6 @@ def write_dropout_mask(x, tensor_name = "tensor"): else: writer.writerows(dropout_mask) -# ------------------------------- -# Autotune -# ------------------------------- -def get_fwd_prefill_cdna_autotune_configs(): - return [ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - # Fall-back config. - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] - - -def get_fwd_prefill_rdna_autotune_configs(): - return [ - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - # Fall-back config. - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] - - -def get_fwd_prefill_autotune_configs(): - if AUTOTUNE: - if is_rdna(): - return get_fwd_prefill_rdna_autotune_configs() - elif is_cdna(): - return get_fwd_prefill_cdna_autotune_configs() - else: - raise ValueError("Unknown Device Type") - else: - arch = get_arch() - if arch == "gfx950": - default_config = triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ) - elif arch == "gfx942" and False: # Disabled due shared mem oom in CI when using triton==3.3.0 when using top of tree everything seems fine. - default_config = triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ) - else: - default_config = triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ) - - return [ - default_config - ], [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_BLOCK_DMODEL", - "IS_VARLEN", - "HQ", - "HK", - ] - # ------------------------------- # Runtime info # ------------------------------- diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py old mode 100644 new mode 100755 index 44d1f027cb0..8ddfa37bd31 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -2,16 +2,24 @@ from typing import Optional, Union, List, Tuple +import os import torch import torch.nn as nn -# isort: off -# We need to import the CUDA kernels after importing torch -import flash_attn_3._C # Registers operators with PyTorch -# isort: on +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" +if USE_TRITON_ROCM: + import sys + sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from flash_attn.flash_attn_triton_amd import interface_fa_v3 as flash_attn_3_gpu +else: + # isort: off + # We need to import the CUDA kernels after importing torch + import flash_attn_3._C # Registers operators with PyTorch -flash_attn_3_cuda = torch.ops.flash_attn_3 + # isort: on + + flash_attn_3_gpu = torch.ops.flash_attn_3 def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x @@ -90,7 +98,7 @@ def _flash_attn_forward( ] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] seqlens_rotary = maybe_contiguous(seqlens_rotary) - out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd( + out, softmax_lse, *rest = flash_attn_3_gpu.fwd( q, k, v, @@ -268,7 +276,7 @@ def _flash_attn_backward( ) -> torch.Tensor: # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - softmax_d, *rest = flash_attn_3_cuda.bwd( + dq, dk, dv, softmax_d, *rest = flash_attn_3_gpu.bwd( dout, q, k, @@ -922,7 +930,7 @@ def flash_attn_varlen_func( def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None): - return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype) + return flash_attn_3_gpu.fwd_combine(out_partial, lse_partial, out, out_dtype) def flash_attn_with_kvcache( @@ -1110,7 +1118,7 @@ def get_scheduler_metadata( cache_seqlens = maybe_contiguous(cache_seqlens) if headdim_v is None: headdim_v = headdim - scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata( + scheduler_metadata = flash_attn_3_gpu.get_scheduler_metadata( batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v, qkv_dtype, cache_seqlens, diff --git a/hopper/setup.py b/hopper/setup.py old mode 100644 new mode 100755 index 95729edabe2..36359229766 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -43,6 +43,10 @@ SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" +# ROCm specific settings +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" +if USE_TRITON_ROCM: + SKIP_CUDA_BUILD = True DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" @@ -421,10 +425,10 @@ def nvcc_threads_args(): cmdclass = {} ext_modules = [] - # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp # files included in the source distribution, in case the user compiles from source. -subprocess.run(["git", "submodule", "update", "--init", "../csrc/cutlass"]) +if not USE_TRITON_ROCM: + subprocess.run(["git", "submodule", "update", "--init", "../csrc/cutlass"]) if not SKIP_CUDA_BUILD: print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) diff --git a/hopper/test_flash_attn_triton_amd.py b/hopper/test_flash_attn_triton_amd.py new file mode 100755 index 00000000000..738ec1d8c13 --- /dev/null +++ b/hopper/test_flash_attn_triton_amd.py @@ -0,0 +1,1174 @@ +import os +import math +import itertools + +import pytest +import torch +import torch.nn.functional as F +from torch._C import parse_schema + +from einops import rearrange, repeat +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + +from padding import pad_input, unpad_input +from test_util import ( + attention_ref, + generate_qkv, + generate_random_padding_mask, +) + +from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine +from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata + + +DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" +DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "TRUE") == "TRUE" +DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" +DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "TRUE") == "TRUE" +DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "TRUE") == "TRUE" +DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "TRUE") == "TRUE" +DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "TRUE") == "TRUE" +DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE" +DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" or torch.cuda.get_device_capability("cuda")[0] < 9 +DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE" +DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE" +DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" +DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" +DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" + +COMPILED_HDIMS = ( + [] + + ([64] if not DISABLE_HDIM64 else []) + + ([96] if not DISABLE_HDIM96 else []) + + ([128] if not DISABLE_HDIM128 else []) + + ([192] if not DISABLE_HDIM192 else []) + + ([256] if not DISABLE_HDIM256 else []) +) + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("V_colmajor", [False, True]) +@pytest.mark.parametrize("V_colmajor", [False]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64, 128, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128, 192]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +def test_flash_attn_output( + seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype +): + if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): + pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") + device = "cuda" + # set seed + torch.random.manual_seed(0) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 1 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4) + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + # window_size = (-1, -1) if not local else (16, 0) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + if V_colmajor: + v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() + # if qv is not None: + # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out = flash_attn_func( + q, + k, + v, + causal=causal, + qv=qv, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not V_colmajor + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g = torch.randn_like(out) + do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) + # import flash_attn_3_cuda + # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd( + # g, + # q, + # k, + # v, + # out, + # lse, + # None, + # None, + # None, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("add_unused_qkv", [False, True]) +# @pytest.mark.parametrize("add_unused_qkv", [True]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (1, 3), + (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype +): + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_q <= 2048 else 2 + nheads = 6 + # batch_size = 2 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) + + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + ( + q_unpad, + k_unpad, + v_unpad, + qv_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, + query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) + q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out_unpad = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + causal=causal, + qv=qv_unpad, + q_descale=q_descale, + k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g_unpad = torch.randn_like(out_unpad) + do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + # import flash_attn_3_cuda + # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( + # g_unpad, + # q_unpad, + # k_unpad, + # v_unpad, + # out_unpad, + # lse, + # None, + # None, + # None, + # cu_seqlens_q, + # cu_seqlens_k, + # None, None, + # max_seqlen_q, + # max_seqlen_k, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() + # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) +# @pytest.mark.parametrize("new_kv", [True]) +@pytest.mark.parametrize("new_kv", [False, True]) +@pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) +# @pytest.mark.parametrize("causal,local", [(False, False)]) +@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False]) +@pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) +# @pytest.mark.parametrize("rotary_interleaved", [True]) +@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) +# @pytest.mark.parametrize("rotary_fraction", [0.0]) +@pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else [])) +# @pytest.mark.parametrize("page_size", [None]) +@pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_leftpad", [False]) +@pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False]) +@pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +@pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + # (1, 128 * 1024), + # (16, 128 * 1024), + (128, 128), + (256, 512), # To test appending KV with more than 1 block + (2048, 3577), # Enough tile to test persistent scheduler + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_kvcache( + seqlen_q, + seqlen_k, + d, + varlen_q, + has_batch_idx, + has_leftpad, + page_size, + rotary_fraction, + rotary_interleaved, + has_rotary_seqlens, + seqlen_new_eq_seqlen_q, + causal, + local, + new_kv, + mha_type, + dtype, +): + if page_size is not None and seqlen_k % page_size != 0: + pytest.skip() + if seqlen_q > seqlen_k and new_kv: + pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + # batch_size = 1 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 + nheads = 6 + # nheads = 1 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + has_qv = d == 64 and dv >= 256 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if has_qv: + qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None + else: + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + + seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v + else: + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref + ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) + if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size)]) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ + :batch_size + ] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() + v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens + ) + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... -> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + attention_chunk=attention_chunk, + key_leftpad=cache_leftpad, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + attention_chunk=attention_chunk, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + precompute_metadata_vals = [False] + for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): + if precompute_metadata: + scheduler_metadata = get_scheduler_metadata( + batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, + cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, + max_seqlen_k_new=seqlen_new, page_size=page_size, + causal=causal, window_size=window_size, attention_chunk=attention_chunk, + num_splits=num_splits + ) + else: + scheduler_metadata = None + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): + if page_size is None: + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) + else: + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + attention_chunk=attention_chunk, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + return_softmax_lse=True + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + if dtype is not torch.float8_e4m3fn: + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + else: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + + +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): + num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + k_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) + v_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache = rearrange( + k_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache = rearrange( + v_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize('d', [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (64, 8192), + ], +) +def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype): + device = "cuda" + torch.random.manual_seed(0) + batch_size = 2 + nheads = 16 + nheads_kv = 4 + # There was a bug where this would cause "unspecified launch failure" due to Cluster + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) + for _ in range(100): + flash_attn_func(q, k, v, causal=causal) + + +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128]) +# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [80]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 239), + (239, 1), + (3, 799), + (799, 3), + (1024, 128), + (97, 97), + (128, 128), + (200, 200), + (256, 256), + (257, 257), + (384, 384), + (512, 512), + (768, 768), + (1024, 1024), + (2048, 2048), + ], +) +@pytest.mark.skip(reason="Cannot be run in parallel with other tests due to memory usage") +def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): + device = "cuda" + # set seed + torch.random.manual_seed(0) + # Simulate under memory load + dummy = torch.empty(70 * 1024 ** 3, dtype=torch.uint8, device=device) + batch_size = 60 # Sometimes we need large batch size for the race conditions to trigger + nheads = 4 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + torch.random.manual_seed(42) + out0 = flash_attn_func(q, k, v, causal=causal) + g = torch.randn_like(out0) + dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g) + # Numerical error if we just do any arithmetic on dq + dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item() + + for i in range(1000): + torch.random.manual_seed(42) + out = flash_attn_func(q, k, v, causal=causal) + assert torch.equal(out, out0) + # assert torch.equal(lse, lse0) + + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_equal = torch.allclose(dq, dq0, atol=dq_atol) + if not dq_equal: + print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}") + # breakpoint() + assert torch.equal(dv, dv0) + assert torch.equal(dk, dk0) + assert dq_equal + + +def attention_combine_ref(out_partial, lse_partial): + """ + out_partial: (num_splits, batch_size, seqlen, nheads, d) + lse_partial: (num_splits, batch_size, nheads, seqlen) + """ + lse = torch.logsumexp(lse_partial, dim=0) + scale = torch.exp(lse_partial - lse) + scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale) + out = (scale.unsqueeze(-1) * out_partial).sum(0) + return out, lse + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float32]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) +# @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) +# @pytest.mark.parametrize("seqlen", [15]) +@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) +# @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) +# @pytest.mark.parametrize("num_splits", [128]) +def test_flash_attn_combine(num_splits, seqlen, d, dtype): + if DISABLE_SPLIT: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(1) + batch_size = 5 + nheads = 16 + # batch_size = 1 + # nheads = 1 + out_partial = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32).transpose(2, 3)[:num_splits] # To test non-contiguous tensor + lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor + # To test short-circuiting based on num_splits + lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf") + out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype) + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + out_pt = out_ref.to(dtype) + + print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + print(f"LSE mean diff: {(lse - lse_ref).abs().mean().item()}") + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) + multiple = 2 + assert ((out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) + + # from flash_attn.utils.benchmark import pytorch_profiler + # # pytorch_profiler(torch.sum, lse_partial) + # pytorch_profiler(flash_attn_combine, out_partial, lse_partial) + # pytorch_profiler(torch.sum, out_partial) + +@pytest.mark.skip(reason="AMD Triton backend doesn't use torch ops registration") +def test_flash3_bw_compatibility() -> None: + # Let's try to always stay backward compatible! This will make life easier + # for downstream libaries, users, and exported models. + # 1/ Instead of removing arguments, error out if their value is no longer supported + # 2/ When adding arguments, add them at the end with a default value + assert torch.ops.flash_attn_3.fwd.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, " + "Tensor(v_new!)? v_new=None, Tensor? q_v=None, Tensor(out!)? out=None, " + "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, " + "Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, Tensor? seqused_k=None, " + "int? max_seqlen_q=None, int? max_seqlen_k=None, Tensor? page_table=None, " + "Tensor? kv_batch_idx=None, Tensor? leftpad_k=None, Tensor? rotary_cos=None, Tensor? rotary_sin=None, " + "Tensor? seqlens_rotary=None, Tensor? q_descale=None, Tensor? k_descale=None, Tensor? v_descale=None, " + "float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " + "int attention_chunk=0, float softcap=0., bool is_rotary_interleaved=False, " + "Tensor? scheduler_metadata=None, int num_splits=0, bool? pack_gqa=None, int sm_margin=0) " + "-> (Tensor(out!), Tensor, Tensor, Tensor)" + )) + assert torch.ops.flash_attn_3.bwd.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, " + "Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, " + "Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, " + "int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, " + "int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) " + "-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)" + )) + assert torch.ops.flash_attn_3.fwd_combine.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, " + "ScalarType? out_dtype=None) -> (Tensor(out!), Tensor)" + )) + assert torch.ops.flash_attn_3.get_scheduler_metadata.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, " + "int num_heads, int num_heads_k, int headdim, int headdim_v, ScalarType qkv_dtype, Tensor seqused_k, " + "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? cu_seqlens_k_new=None, " + "Tensor? seqused_q=None, Tensor? leftpad_k=None, int? page_size=None, int max_seqlen_k_new=0, " + "bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " + "int attention_chunk=0, bool has_softcap=False, int num_splits=0, bool? pack_gqa=None, " + "int sm_margin=0) -> Tensor" + )) diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py index adebe088a79..ba1932438e2 100755 --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -1921,7 +1921,7 @@ def test_flash_attn_splitkv( # @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) -@pytest.mark.parametrize("paged_kv_block_size", [None]) +@pytest.mark.parametrize("paged_kv_block_size", [None, 256]) # @pytest.mark.parametrize("paged_kv_block_size", [256, 512]) # @pytest.mark.parametrize("paged_kv_block_size", [None]) @pytest.mark.parametrize("has_leftpad", [False]) From e7b97cbdb9fa4d22c606cb7f7e2f67e58eab7fc2 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Wed, 24 Sep 2025 23:27:06 -0400 Subject: [PATCH 13/27] AITER integration (#159) * clean up v2 interface * assert fp8 scale shapes * rotary working * move rotary to impl layers * remove einops * enable rotarry in v3 * create interface * fix descale assert * unify bwd * lint from aiter * clean fp8 api * add api change * assert shapes for v2 * remove ref and bench.py * remove metadata class and clean up * bwd_prefill * one bwd.py * rename * lint --- flash_attn/flash_attn_interface.py | 2 +- flash_attn/flash_attn_triton_amd/__init__.py | 4 + flash_attn/flash_attn_triton_amd/bench.py | 1391 ----- flash_attn/flash_attn_triton_amd/bwd.py | 4997 +++++++++++++++++ .../bwd_prefill_fused_atomics.py | 1815 ------ .../bwd_prefill_fused_no_atomics.py | 1467 ----- .../bwd_prefill_split.py | 1360 ----- flash_attn/flash_attn_triton_amd/bwd_ref.py | 545 -- .../flash_attn_triton_amd/fwd_decode.py | 643 ++- .../flash_attn_triton_amd/fwd_prefill.py | 1573 ++++-- flash_attn/flash_attn_triton_amd/fwd_ref.py | 889 --- .../flash_attn_triton_amd/interface_fa.py | 927 --- .../flash_attn_triton_amd/interface_fa_v3.py | 660 --- .../flash_attn_triton_amd/interface_v2.py | 674 +++ .../flash_attn_triton_amd/interface_v3.py | 608 ++ flash_attn/flash_attn_triton_amd/utils.py | 1372 +++-- hopper/flash_attn_interface.py | 2 +- hopper/test_flash_attn_triton_amd.py | 5 +- 18 files changed, 8900 insertions(+), 10034 deletions(-) delete mode 100755 flash_attn/flash_attn_triton_amd/bench.py create mode 100755 flash_attn/flash_attn_triton_amd/bwd.py delete mode 100755 flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py delete mode 100755 flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py delete mode 100755 flash_attn/flash_attn_triton_amd/bwd_prefill_split.py delete mode 100644 flash_attn/flash_attn_triton_amd/bwd_ref.py delete mode 100755 flash_attn/flash_attn_triton_amd/fwd_ref.py delete mode 100644 flash_attn/flash_attn_triton_amd/interface_fa.py delete mode 100755 flash_attn/flash_attn_triton_amd/interface_fa_v3.py create mode 100644 flash_attn/flash_attn_triton_amd/interface_v2.py create mode 100755 flash_attn/flash_attn_triton_amd/interface_v3.py diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 865f1db5432..a53b4a3108a 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -10,7 +10,7 @@ # We need to import the CUDA kernels after importing torch USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" if USE_TRITON_ROCM: - from .flash_attn_triton_amd import interface_fa as flash_attn_gpu + from .flash_attn_triton_amd import flash_attn_2 as flash_attn_gpu else: import flash_attn_2_cuda as flash_attn_gpu diff --git a/flash_attn/flash_attn_triton_amd/__init__.py b/flash_attn/flash_attn_triton_amd/__init__.py index e69de29bb2d..78f85fb268f 100644 --- a/flash_attn/flash_attn_triton_amd/__init__.py +++ b/flash_attn/flash_attn_triton_amd/__init__.py @@ -0,0 +1,4 @@ +from . import interface_v2 as flash_attn_2 +from . import interface_v3 as flash_attn_3 + +__all__ = ["flash_attn_2", "flash_attn_3"] diff --git a/flash_attn/flash_attn_triton_amd/bench.py b/flash_attn/flash_attn_triton_amd/bench.py deleted file mode 100755 index e19de575c8c..00000000000 --- a/flash_attn/flash_attn_triton_amd/bench.py +++ /dev/null @@ -1,1391 +0,0 @@ -import os -import sys -import torch -import triton -import time -import argparse -import itertools -import logging -import warnings -import datetime -import pandas as pd -from logging import warning -from typing import Dict, List, Literal, Optional, Tuple -from dataclasses import dataclass -from functools import lru_cache -from utils import get_arch, input_helper - -DEBUG = False - -ENV_FLAGS = ["FLASH_ATTENTION_TRITON_AMD_ENABLE", "FLASH_ATTENTION_TRITON_AMD_AUTOTUNE", "FLASH_ATTENTION_TRITON_AMD_DEBUG"] - -FUNCTIONS = [ - "flash_attn_func", - "flash_attn_fp8_func", - "flash_attn_kvpacked_func", - "flash_attn_qkvpacked_func", - "flash_attn_qkvpacked_fp8_func", - "flash_attn_varlen_func", - "flash_attn_varlen_fp8_func", - "flash_attn_varlen_kvpacked_func", - "flash_attn_varlen_qkvpacked_func", - "flash_attn_varlen_qkvpacked_fp8_func", - "flash_attn_with_kvcache", -] - -SUPPORTED_DTYPES = { - "flash_attn_func": [torch.float16], - "flash_attn_fp8_func": [torch.float8_e4m3fnuz], - "flash_attn_kvpacked_func": [torch.float16], - "flash_attn_qkvpacked_func": [torch.float16], - "flash_attn_qkvpacked_fp8_func": [torch.float16], - "flash_attn_varlen_func": [torch.float16], - "flash_attn_varlen_fp8_func": [torch.float8_e4m3fnuz], - "flash_attn_varlen_kvpacked_func": [torch.float16], - "flash_attn_varlen_qkvpacked_func": [torch.float16], - "flash_attn_varlen_qkvpacked_fp8_func": [torch.float16], - "flash_attn_with_kvcache": [torch.float16], -} - -SUPPORTED_BACKENDS = { - "flash_attn_func": ["ck", "triton"], - "flash_attn_fp8_func": ["triton"], - "flash_attn_kvpacked_func": ["ck", "triton"], - "flash_attn_qkvpacked_func": ["ck", "triton"], - "flash_attn_qkvpacked_fp8_func": ["triton"], - "flash_attn_varlen_func": ["ck", "triton"], - "flash_attn_varlen_fp8_func": ["triton"], - "flash_attn_varlen_kvpacked_func": ["ck", "triton"], - "flash_attn_varlen_qkvpacked_func": ["ck", "triton"], - "flash_attn_varlen_qkvpacked_fp8_func": ["triton"], - "flash_attn_with_kvcache": ["ck", "triton"], -} - -VALID_MODES = ['fwd', 'bwd', 'full'] -SUPPORTED_MODES = { - "flash_attn_func": ["fwd", "bwd", "full"], - "flash_attn_fp8_func": ["fwd", "bwd", "full"], - "flash_attn_kvpacked_func": ["fwd", "bwd", "full"], - "flash_attn_qkvpacked_func": ["fwd", "bwd", "full"], - "flash_attn_qkvpacked_fp8_func": ["fwd", "bwd", "full"], - "flash_attn_varlen_func": ["fwd", "bwd", "full"], - "flash_attn_varlen_fp8_func": ["fwd", "bwd", "full"], - "flash_attn_varlen_kvpacked_func": ["fwd", "bwd", "full"], - "flash_attn_varlen_qkvpacked_func": ["fwd", "bwd", "full"], - "flash_attn_varlen_qkvpacked_fp8_func": ["fwd", "bwd", "full"], - "flash_attn_with_kvcache": ["fwd"], -} - - -# Add a global variable for verbose mode -VERBOSE = False - -@dataclass -class EnvVariableConfig: - key: str - values: List[str] - backend: Optional[Literal["triton", "ck"]] = None - -ENV_VARIABLE_CONFIGS : List[EnvVariableConfig] = [ - # EnvVariableConfig(key="BWD_MODE", values=["split", "fused_atomics", "fused_no_atomics"], backend="triton"), -] - -class FunctionConfig: - def __init__(self, fn_name: str, mode: Literal["fwd", "bwd", "full"], dtype, backend: Literal["triton", "ck"], env_config: Dict): - self.fn_name = fn_name - self.mode: Literal["fwd", "bwd", "full"] = mode - self.dtype = dtype - self.backend: Literal["triton", "ck"] = backend - self.arch = get_arch() - self.env_configs = env_config - - def __str__(self): - # extract base dtype name if it's a torch dtype - dtype_str = str(self.dtype) - if "torch." in dtype_str: - dtype_str = dtype_str.split(".")[-1] - - if len(self.env_configs) > 0: - env_str = "" - for env_key, env_value in self.env_configs.items(): - env_str += f"{env_key}={env_value}" - return f"{self.fn_name}_{self.mode}_{dtype_str}_{self.backend}_{self.arch}_{env_str}" - else: - return f"{self.fn_name}_{self.mode}_{dtype_str}_{self.backend}_{self.arch}" - - def column_name(self): - return f"{self}_ms" -def generate_fn_inputs( - fn_name: str, - BATCH: int, - HQ: int, - HK: int, - N_CTX_Q: int, - N_CTX_K: int, - D_HEAD: int, - CAUSAL: bool, - DROPOUT_P: float, - dtype: torch.dtype, - device: Literal["cpu", "cuda"] - ): - if fn_name == "flash_attn_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", device=device) - elif fn_name == "flash_attn_kvpacked_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", packing="kv", device=device) - elif fn_name == "flash_attn_qkvpacked_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", packing="qkv", device=device) - elif fn_name == "flash_attn_varlen_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", device=device) - elif fn_name == "flash_attn_varlen_kvpacked_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", packing="kv", device=device) - elif fn_name == "flash_attn_varlen_qkvpacked_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", packing="qkv", device=device) - elif fn_name == "flash_attn_with_kvcache": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", device=device) - elif fn_name == "flash_attn_fp8_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", device=device) - elif fn_name == "flash_attn_qkvpacked_fp8_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", packing="qkv", device=device) - elif fn_name == "flash_attn_varlen_fp8_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", device=device) - elif fn_name == "flash_attn_varlen_qkvpacked_fp8_func": - return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", packing="qkv", device=device) - else: - valid_fn_names = ", ".join(FUNCTIONS) - raise ValueError(f"{fn_name} should be one of the following functions. {valid_fn_names}") - -def estimate_memory(config): - batch, hq, hk, sq, sk, d_head, causal, dropout = config - memory_estimate = batch * (hq * sq + hk * sk) * d_head * 4 # bytes - return memory_estimate - -def generate_benchmark_configs(is_varlen: bool, packing: Optional[Literal["kv", "qkv"]]): - """ - generates a small number of configs that cover the parameter space well - """ - - # define all parameter options as lists - batch_sizes = [1, 64] - if packing == "qkv": - hq_values = hk_values = [2, 8] - sq_values = sk_values = [256, 8192] - else: - if is_varlen: # make sure the seqlen is greater than the batchsize so that subsequences are greater than 0 - hq_values = [16, 32] # test mqa/gqa - hk_values = [8, 16] - sq_values = [128, 512] - sk_values = [512, 2024] - else: - hq_values = [64, 128] # test mqa/gqa - hk_values = [16, 64] - sq_values = [4, 4096] - sk_values = [4096, 16384] # test large k values for inference perf - d_head_values = [64, 128] - causal_values = [True, False] # most models usual causal True - dropout_values = [0.0, 0.1] - - # generate all fn_configs without inputs - input_configs = [] - - # one big loop to generate configs - for batch in batch_sizes: - for hq in hq_values: - for hk in hk_values: - for sq in sq_values: - for sk in sk_values: - for d_head in d_head_values: - for causal in causal_values: - for dropout in dropout_values: - # filter configs - input_config = (batch, hq, hk, sq, sk, d_head, causal, dropout) - - # skip if memory usage would be too high - if estimate_memory(input_config) > 8 * 1024 * 1024 * 1024: # 8 GB limit - continue - - # we need hq to be a multiple of hk - if hq % hk != 0: - continue - - # for qkvpacked functions, q and k must have same dimensions - if packing == "qkv" and (sq != sk or hq != hk): - continue - - input_configs.append(input_config) - - return input_configs - -def create_benchmark_fn( - flash_attn, - fn_name, - fn_input, - mode: Literal["fwd", "bwd", "full"] -): - if DEBUG: - print("create_benchmark_fn") - print("flash_attn:", flash_attn) - print("fn_name:", fn_name) - print("fn_input:", len(fn_input)) - print("mode:", mode) - - if fn_name == "flash_attn_func": - q, k, v, do, metadata = fn_input - if mode == "fwd": - def flash_attn_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_func( - q, - k, - v, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out - elif mode == "bwd": - out, lse, S_dmask = flash_attn.flash_attn_func( - q, - k, - v, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_bench_fn(): - dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) - return dq, dk, dv - elif mode == "full": - def flash_attn_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_func( - q, - k, - v, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) - return dq, dk, dv - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_bench_fn - - elif fn_name == "flash_attn_kvpacked_func": - q, kv, do, metadata = fn_input - if mode == "fwd": - def flash_attn_kvpacked_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_kvpacked_func( - q, - kv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out - elif mode == "bwd": - out, lse, S_dmask = flash_attn.flash_attn_kvpacked_func( - q, - kv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_kvpacked_bench_fn(): - dq, dkv = torch.autograd.grad(out, (q, kv), do, retain_graph=True) - return dq, dkv - elif mode == "full": - def flash_attn_kvpacked_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_kvpacked_func( - q, - kv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dq, dkv = torch.autograd.grad(out, (q, kv), do, retain_graph=True) - return dq, dkv - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_kvpacked_bench_fn - elif fn_name == "flash_attn_qkvpacked_func": - qkv, do, metadata = fn_input - if mode == "fwd": - def flash_attn_qkvpacked_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_func( - qkv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out - elif mode == "bwd": - out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_func( - qkv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_qkvpacked_bench_fn(): - dqkv = torch.autograd.grad(out, (qkv), do, retain_graph=True) - return dqkv - elif mode == "full": - def flash_attn_qkvpacked_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_func( - qkv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dqkv = torch.autograd.grad(out, (qkv), do, retain_graph=True) - return dqkv - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_qkvpacked_bench_fn - elif fn_name == "flash_attn_varlen_func": - q_unpad, k_unpad, v_unpad, do_unpad, metadata = fn_input - if mode == "fwd": - def flash_attn_varlen_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_func( - q_unpad, - k_unpad, - v_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out_unpad - elif mode == "bwd": - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_func( - q_unpad, - k_unpad, - v_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_varlen_bench_fn(): - dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), do_unpad, retain_graph=True) - return dq_unpad, dk_unpad, dv_unpad - elif mode == "full": - def flash_attn_varlen_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_func( - q_unpad, - k_unpad, - v_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), do_unpad, retain_graph=True) - return dq_unpad, dk_unpad, dv_unpad - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_varlen_bench_fn - elif fn_name == "flash_attn_varlen_kvpacked_func": - q_unpad, kv_unpad, do_unpad, metadata = fn_input - if mode == "fwd": - def flash_attn_varlen_kvpacked_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_kvpacked_func( - q_unpad, - kv_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out_unpad - elif mode == "bwd": - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_kvpacked_func( - q_unpad, - kv_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_varlen_kvpacked_bench_fn(): - dq_unpad, dkv_unpad = torch.autograd.grad(out_unpad, (q_unpad, kv_unpad), do_unpad, retain_graph=True) - return dq_unpad, dkv_unpad - elif mode == "full": - def flash_attn_varlen_kvpacked_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_kvpacked_func( - q_unpad, - kv_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dq_unpad, dkv_unpad = torch.autograd.grad(out_unpad, (q_unpad, kv_unpad), do_unpad, retain_graph=True) - return dq_unpad, dkv_unpad - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_varlen_kvpacked_bench_fn - elif fn_name == "flash_attn_varlen_qkvpacked_func": - qkv_unpad, do_unpad, metadata = fn_input - if mode == "fwd": - def flash_attn_varlen_qkvpacked_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_func( - qkv_unpad, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out_unpad - elif mode == "bwd": - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_func( - qkv_unpad, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_varlen_qkvpacked_bench_fn(): - dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) - return dqkv_unpad - elif mode == "full": - def flash_attn_varlen_qkvpacked_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_func( - qkv_unpad, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) - return dqkv_unpad - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_varlen_qkvpacked_bench_fn - elif fn_name == "flash_attn_with_kvcache": - q, k_cache, v_cache, _, metadata = fn_input - if mode == "fwd": - def flash_attn_with_kvcache_bench_fn(): - out = flash_attn.flash_attn_with_kvcache( - q, - k_cache, - v_cache, - None, - None, - rotary_cos=None, - rotary_sin=None, - cache_seqlens=None, - cache_batch_idx=None, - cache_leftpad=None, - block_table=None, - causal=metadata.causal, - window_size=(-1, -1), - rotary_interleaved=False, - alibi_slopes=None, - num_splits=0, - ) - return out - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_with_kvcache_bench_fn - elif fn_name == "flash_attn_fp8_func": - (q, descale_q), (k, descale_k), (v, descale_v), (do, descale_do), metadata = fn_input - if mode == "fwd": - def flash_attn_f8_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_fp8_func( - q, - k, - v, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out - elif mode == "bwd": - out, lse, S_dmask = flash_attn.flash_attn_fp8_func( - q, - k, - v, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_f8_bench_fn(): - dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) - return dq, dk, dv - elif mode == "full": - def flash_attn_f8_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_fp8_func( - q, - k, - v, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) - return dq, dk, dv - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_f8_bench_fn - elif fn_name == "flash_attn_qkvpacked_fp8_func": - qkv, do, metadata = fn_input - if mode == "fwd": - def flash_attn_qkvpacked_fp8_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_fp8_func( - qkv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out - elif mode == "bwd": - out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_fp8_func( - qkv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_qkvpacked_fp8_bench_fn(): - dqkv = torch.autograd.grad(out, (qkv), do, retain_graph=True) - return dqkv - elif mode == "full": - def flash_attn_qkvpacked_fp8_bench_fn(): - out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_fp8_func( - qkv, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0, - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dqkv = torch.autograd.grad(out, (qkv), do, retain_graph=True) - return dqkv - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_qkvpacked_fp8_bench_fn - elif fn_name == "flash_attn_varlen_fp8_func": - (q_unpad, descale_q), (k_unpad, descale_k), (v_unpad, descale_v), (do_unpad, descale_do), metadata = fn_input - if mode == "fwd": - def flash_attn_varlen_fp8_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_fp8_func( - q_unpad, - k_unpad, - v_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out_unpad - elif mode == "bwd": - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_fp8_func( - q_unpad, - k_unpad, - v_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_varlen_fp8_bench_fn(): - dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), do_unpad, retain_graph=True) - return dq_unpad, dk_unpad, dv_unpad - elif mode == "full": - def flash_attn_varlen_fp8_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_fp8_func( - q_unpad, - k_unpad, - v_unpad, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), do_unpad, retain_graph=True) - return dq_unpad, dk_unpad, dv_unpad - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_varlen_fp8_bench_fn - elif fn_name == "flash_attn_varlen_qkvpacked_fp8_func": - qkv_unpad, do_unpad, metadata = fn_input - if mode == "fwd": - def flash_attn_varlen_qkvpacked_fp8_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( - qkv_unpad, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - return out_unpad - elif mode == "bwd": - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( - qkv_unpad, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - def flash_attn_varlen_qkvpacked_fp8_bench_fn(): - dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) - return dqkv_unpad - elif mode == "full": - def flash_attn_varlen_qkvpacked_fp8_bench_fn(): - out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( - qkv_unpad, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - metadata.dropout_p, - causal=metadata.causal, - window_size=(-1, -1), - softcap=0.0 , - alibi_slopes=None, - deterministic=False, - return_attn_probs=True, - ) - dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) - return dqkv_unpad - else: - raise ValueError(f"Unsupported benchmarking mode: {mode}") - - return flash_attn_varlen_qkvpacked_fp8_bench_fn - else: - valid_fn_names = ", ".join(FUNCTIONS) - raise ValueError(f"{fn_name} should be one of the following functions. {valid_fn_names}") - -def get_packing_type(fn_name: str) -> Optional[Literal["kv", "qkv"]]: - if "_kvpacked" in fn_name: - packing = "kv" - elif "_qkvpacked" in fn_name: - packing = "qkv" - else: - packing = None - - return packing - -def load_flash_attn_module(backend: Literal["triton", "ck"], env_configs: Dict = {}): - """ - Load the flash_attn module with the specified backend configuration - """ - global VERBOSE - - # remove any existing env variables first - for key in ENV_FLAGS: - if key in os.environ: - del os.environ[key] - - # set environment variable for the desired backend - if backend == "triton": - os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE" - os.environ["FLASH_ATTENTION_TRITON_AMD_DEBUG"] = "0" - os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "1" - elif backend == "ck": - os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "FALSE" - else: - raise ValueError(f"Unknown backend {backend}") - - # add custom env configs - add_env_configs(env_configs) - - if VERBOSE: # Only print if both local and global verbose are True - print(f"Loading flash_attn module with {backend} backend.") - - # Remove any existing flash_attn modules from sys.modules - for module_name in list(sys.modules.keys()): - if module_name.startswith('flash_attn'): - del sys.modules[module_name] - - # Clear CUDA cache - torch.cuda.empty_cache() - - # Import and return the module - import flash_attn - - # disable triton printing from autotuning - if not VERBOSE: - os.environ["TRITON_PRINT_AUTOTUNING"] = "0" - - return flash_attn - -def add_env_configs(env_config: Dict): - for env_key, env_value in env_config.items(): - if env_key in os.environ: - del os.environ[env_key] # remove previous version so that env key is the latest key added - os.environ[env_key] = env_value - -def run_benchmark(func_config: FunctionConfig, input_configs): - """ - Runs the benchmark for the provided function configuration with the given input configurations. - """ - global VERBOSE - - # extract function configuration parameters - fn_name = func_config.fn_name - mode = func_config.mode - dtype = func_config.dtype - backend = func_config.backend - - # load flash attention module - flash_attn_module = load_flash_attn_module(backend, func_config.env_configs) - - # start timing the benchmark - start_time = time.time() - if VERBOSE: - print(f"Benchmarking {func_config} ...") - else: - print(f"Running {fn_name} ({mode}, {backend})...", end='', flush=True) - - # Setup benchmark configurations - bench_configs = [ - triton.testing.Benchmark( - x_names=["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K", "D_HEAD", "CAUSAL", "DROPOUT"], - x_vals=list(input_configs.keys()), - line_arg="provider", - line_vals=["triton"], - line_names=["Time (ms)"], - styles=[("red", "-")], - ylabel="ms", - plot_name=f"benchmark-{func_config}", - args={ - }, - ) - ] - - @triton.testing.perf_report(bench_configs) - def bench_function( - BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT, provider, device="cuda" - ): - if DEBUG: - print("BATCH:", BATCH) - print("HQ:", HQ) - print("HK:", HK) - print("N_CTX_Q:", N_CTX_Q) - print("N_CTX_Q:", N_CTX_Q) - print("D_HEAD:", D_HEAD) - print("CAUSAL:", CAUSAL) - print("DROPOUT:", DROPOUT) - print("mode:", mode) - print("provider:", provider) - print("device:", device) - fn_input = input_configs[(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT)] - benchmark_fn = create_benchmark_fn(flash_attn_module, fn_name, fn_input, mode) - - # run the benchmark - ms = triton.testing.do_bench(benchmark_fn, warmup=25, rep=100) - return ms - - df = bench_function.run(return_df=True)[0] - - # set the column name to reflect the function configuration - df = df.rename(columns={"Time (ms)": func_config.column_name()}) - - # calculate and print elapsed time - elapsed_time = time.time() - start_time - - return df, elapsed_time - -def filter_modes(requested_modes, fn_name, supported_modes_for_fn): - modes_to_run = [] - if requested_modes: - for mode in requested_modes: - if mode in supported_modes_for_fn: - modes_to_run.append(mode) - else: - warning(f"Mode '{mode}' requested but not supported by function '{fn_name}'. Skipping this mode for this function.") - else: - modes_to_run = ["full" if "full" in supported_modes_for_fn else "fwd"] - return modes_to_run - -def get_env_value_combinations(current_backend: Optional[Literal["triton", "ck"]]) -> List[Dict[str, str]]: - # filter environment variations applicable to the current backend - applicable_variations = [ - var_config for var_config in ENV_VARIABLE_CONFIGS - if var_config.backend is None or var_config.backend == current_backend - ] - - if not applicable_variations: - # no applicable variations, return list with empty dict - return [{}] - - # prepare keys and value lists - variation_keys = [v.key for v in applicable_variations] - variation_value_lists = [v.values for v in applicable_variations] - - # generate all combinations as dictionaries directly - env_configs = [] - for value_combination in itertools.product(*variation_value_lists): - env_configs.append(dict(zip(variation_keys, value_combination))) - - return env_configs - -def get_input_config_set(config_type): - if config_type == "llama": - # batch, hq, hk, sq, sk, d_head, causal, dropout - input_configs = [ - # LLaMA 3 8B - (4, 32, 8, 8192, 8192, 128, True, 0.0), - # LLaMA 3 70B - (4, 64, 8, 8192, 8192, 128, True, 0.0), - ] - else: - raise ValueError(f"Unknown input config: {config_type}") - - return input_configs - -def available_backends(): - """Check which backends are available by trying to load them.""" - available = [] - - for backend in ["triton", "ck"]: - try: - # try loading the module with this backend - load_flash_attn_module(backend) - available.append(backend) - except Exception as e: - # backend not available, just continue - if DEBUG: - print(f"Backend {backend} not available: {e}") - - if not available: - raise ValueError("No backends are available. Please check your flash_attn installation.") - - return available - -# 2. Simplify get_fn_params to remove the backend filtering logic here -@lru_cache() -def get_fn_params(fn_name): - # get params for fn - packing = get_packing_type(fn_name) - is_varlen = True if "varlen" in fn_name else False - is_fp8 = True if "fp8" in fn_name else False - supported_dtypes = SUPPORTED_DTYPES.get(fn_name, [torch.float16]) - supported_backends = SUPPORTED_BACKENDS.get(fn_name, ["triton"]) # just get what the function supports - supports_backward = False if fn_name in ["flash_attn_with_kvcache"] else True - supported_modes = SUPPORTED_MODES.get(fn_name, ["fwd"]) - device = "cuda" - - # get supported env configs for each backend - supported_env_configs = {} - for backend in supported_backends: - supported_env_configs[backend] = get_env_value_combinations(backend) - - # check backward pass support - if not supports_backward: - warning(f"{fn_name} does not have a backward pass so benching forward pass only.") - - return is_varlen, is_fp8, packing, supported_dtypes, supported_backends, supported_modes, supported_env_configs, device - -# 3. Create a new simpler function to validate and filter backends -def validate_backends(requested_backends, supported_backends, fn_name): - """Validate that requested backends are available and supported.""" - # get actually available backends - available = available_backends() - - # determine which backends to use - if requested_backends: - # user specified backends - validate them - valid_backends = [] - for backend in requested_backends: - if backend not in available: - warning(f"Backend '{backend}' is not available on this system. Skipping.") - continue - if backend not in supported_backends: - warning(f"Backend '{backend}' is not supported by function '{fn_name}'. Skipping.") - continue - valid_backends.append(backend) - - if not valid_backends: - raise ValueError(f"None of the requested backends {requested_backends} are available and supported for {fn_name}") - - return valid_backends - else: - # no backends specified - use all available and supported - valid_backends = [b for b in supported_backends if b in available] - - if not valid_backends: - raise ValueError(f"No available backends found for {fn_name}. Function supports {supported_backends} but only {available} are available.") - - return valid_backends - -# 4. Update process_args to use the new validate_backends function -def process_args(): - """ - Parses command-line arguments and returns function configs and input configs. - """ - global VERBOSE - - # create parser - parser = argparse.ArgumentParser( - prog="Benchmark FlashAttention", - allow_abbrev=False, - ) - # functions - parser.add_argument( - "-benchmark_fn", - type=str, - nargs="*", - choices=FUNCTIONS, - required=True, - help=f"Function(s) to benchmark", - ) - parser.add_argument( - "--mode", - type=str, - nargs='*', - choices=VALID_MODES, - default=["fwd", "bwd"], - help=f"Benchmarking mode(s) to run. Default: fwd, bwd", - ) - parser.add_argument( - "--backend", - type=str, - nargs='*', - choices=["triton", "ck"], - default=["triton"], - help="Backend(s) to run. Default: triton", - ) - parser.add_argument( - "--output", - type=str, - choices=["ms", "tflops"], - default="tflops", - help="Output metric type: ms (milliseconds) or tflops (TFLOPS). Default: tflops", - ) - parser.add_argument( - "--format", - type=str, - choices=["csv", "markdown"], - default="csv", - help="Output file format: csv or markdown. Default: csv", - ) - parser.add_argument( - "--verbose", "-v", - action="store_true", - help="Enable verbose output (show autotuning details)", - ) - # config - parser.add_argument("-b", type=int, default=None, help="Batch size") - parser.add_argument("-hq", type=int, default=None, help="Q Number of heads") - parser.add_argument("-hk", type=int, default=None, help="K and V Number of heads") - parser.add_argument("-sq", type=int, default=None, help="Q Sequence Length") - parser.add_argument("-sk", type=int, default=None, help="K and V Sequence Length") - parser.add_argument("-d", type=int, default=None, help="Head Dimension") - parser.add_argument("-causal", action="store_true", default=None, help="Causal") - parser.add_argument("-dropout", type=float, default=None, help="Dropout") - - # parse args - args = parser.parse_args() - - # Set global verbose flag - VERBOSE = args.verbose - - # parse function args - benchmark_fns = args.benchmark_fn - requested_modes = args.mode - requested_backends = args.backend - output_type: Literal["ms", "tflops"] = args.output - output_format: Literal["csv", "markdown"] = args.format - - # generate function configurations and input configurations separately - all_function_configs = [] - all_input_configs = {} # Maps function config -> input configs - - for fn_name in benchmark_fns: - is_varlen, is_fp8, packing, supported_dtypes, supported_backends, supported_modes_for_fn, supported_env_configs, device = get_fn_params(fn_name) - - # Generate or use custom input configurations - if args.b or args.hq or args.hk or args.sq or args.sk or args.d: - assert args.b and args.hq and args.sq and args.d, ( - "if custom config is specified, please provide at least batch, number of Q heads, Q sequence length, and head size." - ) - - batch = args.b - hq = args.hq - hk = args.hk if args.hk is not None else args.hq - sq = args.sq - sk = args.sk if args.sk is not None else args.sq - d_head = args.d - causal = args.causal if args.causal is not None else False - dropout = args.dropout if args.dropout is not None else 0.0 - input_configs = [(batch, hq, hk, sq, sk, d_head, causal, dropout)] - else: - input_configs = get_input_config_set("llama") - - # filter by mode - modes_to_run = filter_modes(requested_modes, fn_name, supported_modes_for_fn) - if not modes_to_run: - warning(f"No valid modes to run for function '{fn_name}' based on request and function support. Skipping this function.") - continue - - # validate and filter backends - try: - backends_to_run = validate_backends(requested_backends, supported_backends, fn_name) - except ValueError as e: - warning(str(e)) - continue - - # create a function config for each backend and dtype combination - for backend in backends_to_run: - for dtype in supported_dtypes: - for mode in modes_to_run: - for env_config in supported_env_configs[backend]: - func_config = FunctionConfig(fn_name, mode, dtype, backend, env_config) - all_function_configs.append(func_config) - - # Generate inputs for this function configuration - fn_inputs = {} - for input_config in input_configs: - fn_inputs[input_config] = generate_fn_inputs(fn_name, *input_config, dtype, device) - - all_input_configs[func_config] = fn_inputs - - return all_function_configs, all_input_configs, output_type, output_format - -def check_environment_variables(): - for key in ENV_FLAGS: - if key in os.environ: - raise ValueError(f"Running with {key} environment variable is not recommended for the benching script. Use --help to see how to use the benching script.") - -def compute_flops(batch, hq, hk, sq, sk, d_head, causal): - # 2 FLOPs per multiply‑add - if causal: - valid_pairs = ((sk * (sk + 1)) // 2 if sq > sk else - sq * sk - (sq * (sq - 1)) // 2) - else: - valid_pairs = sq * sk - return 2 * batch * hq * valid_pairs * d_head - -# see ref, https://github.com/ROCm/aiter/blob/jukorhon/mha-bwd/op_benchmarks/triton/bench_mha.py -def _flops_single_row(row: pd.Series, mode: str) -> float: - b, hq, d_head = int(row["BATCH"]), int(row["HQ"]), int(row["D_HEAD"]) - sq, sk = int(row["N_CTX_Q"]), int(row["N_CTX_K"]) - causal = bool(row["CAUSAL"]) - - # -------- number of (query, key) products per head ---------------- - if not causal: - valid_pairs = sq * sk - else: # triangular mask - if sq > sk: - valid_pairs = sk * (sk + 1) // 2 + (sq - sk) * sk - else: # sq <= sk - valid_pairs = sq * (sq + 1) // 2 - - # one matmul FLOPs (mul + add) = 2 · m · n · k - flops_per_matmul = 2.0 * b * hq * valid_pairs * d_head - total_flops = 2.0 * flops_per_matmul # 2 matmuls in forward - - if mode == "fwd": - pass - elif mode == "bwd": - total_flops *= 2.5 # 2·bwd + 0.5·recompute - elif mode == "full": - total_flops *= 3.5 # fwd + bwd - else: - raise ValueError(f"unknown mode {mode}") - - return total_flops - -def add_tflops_columns(df: pd.DataFrame, func_cfg: FunctionConfig) -> pd.DataFrame: - ms_col = func_cfg.column_name() - tf_col = ms_col.replace("_ms", "_tflops") - flops = df.apply(_flops_single_row, axis=1, mode=func_cfg.mode) - df[tf_col] = flops / df[ms_col] * 1e-9 - return df - -def generate_output_filename(function_configs, output_type, output_format): - # create a timestamp - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - - # simple filename format - base_filename = f"benchmark_{timestamp}" - - if output_format == "csv": - return base_filename + ".csv" - else: # markdown - return base_filename + ".md" - -def main(): - """ - Main function to run benchmarks. - """ - global VERBOSE - - # check environment variables - check_environment_variables() - - # start timing the entire benchmarking process - total_start_time = time.time() - - # process args to get function configs and input configs - function_configs, all_input_configs, output_type, output_format = process_args() - - # Print summary of what will be benchmarked (always show this) - print(f"\nBenchmarking {len(function_configs)} configuration(s):") - unique_fns = set(fc.fn_name for fc in function_configs) - print(f" Functions: {', '.join(unique_fns)}") - unique_backends = set(fc.backend for fc in function_configs) - print(f" Backends: {', '.join(unique_backends)}") - unique_modes = set(fc.mode for fc in function_configs) - print(f" Modes: {', '.join(unique_modes)}") - print() - - # run benchmarks for each function configuration - combined_ms_df = None - combined_tf_df = None - input_cols = ["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K", "D_HEAD", "CAUSAL", "DROPOUT"] - - for i, func_config in enumerate(function_configs, 1): - # Progress indicator - if not VERBOSE: - print(f"[{i}/{len(function_configs)}] ", end='') - - # run benchmark with the input configs for this function config - input_configs = all_input_configs[func_config] - df, elapsed_time = run_benchmark(func_config, input_configs) - - if VERBOSE: - print(f"Total time for benchmarking {func_config.fn_name} in {func_config.mode} mode with {func_config.dtype}: {elapsed_time:.2f} seconds") - - # add to combined table - df = add_tflops_columns(df, func_config) - ms_cols = [c for c in df.columns if c.endswith('_ms')] - tf_cols = [c for c in df.columns if c.endswith('_tflops')] - - ms_df = df[input_cols + ms_cols] - tf_df = df[input_cols + tf_cols] - - if combined_ms_df is None: - combined_ms_df = ms_df - combined_tf_df = tf_df - else: - combined_ms_df = combined_ms_df.merge(ms_df, on=input_cols, how="outer") - combined_tf_df = combined_tf_df.merge(tf_df, on=input_cols, how="outer") - - # print new line to seperate the combined data information from the benchmark specific information - print() - - # print total time for all benchmarks - total_elapsed_time = time.time() - total_start_time - print(f"Total benchmark time: {total_elapsed_time:.1f} seconds") - - # save combined data and make comparisons if we have multiple function configs - has_multiple_func_configs = False # len(function_configs) > 1 - if has_multiple_func_configs: - if len(function_configs) == 2: - func1 = function_configs[0] - func2 = function_configs[1] - - # construct column names for the timing results - col1 = func1.column_name() - col2 = func2.column_name() - - # Check if we're comparing triton vs ck (in either order) - is_triton_vs_ck = ( - (func1.backend == "triton" and func2.backend == "ck") or - (func1.backend == "ck" and func2.backend == "triton") - ) - - # For triton vs ck comparisons - if is_triton_vs_ck: - # For triton vs ck comparisons, always make triton the baseline - if func1.backend == "triton" and func2.backend == "ck": - triton_col = col1 - ck_col = col2 - ratio_col = f"ck_to_triton_ratio" - else: - triton_col = col2 - ck_col = col1 - ratio_col = f"ck_to_triton_ratio" - - # Calculate ratio: ck_time / triton_time (values > 1 mean triton is faster) - combined_ms_df[ratio_col] = combined_ms_df[ck_col] / combined_ms_df[triton_col] - - # print explanation - print(f"Comparison Results (triton vs ck):") - print(f"Ratio values: values > 1 mean triton is faster (by that factor), values < 1 mean ck is faster") - - # output based on selected metric - if output_type == "ms": - if combined_ms_df is not None: - filename = generate_output_filename(function_configs, "ms", output_format) - print(f"\nCombined wall-time (ms) table:") - print(combined_ms_df) - - if output_format == "csv": - combined_ms_df.to_csv(filename, index=False) - print(f"Results saved to: {filename}") - else: # markdown - with open(filename, 'w') as f: - f.write(combined_ms_df.to_markdown(index=False, floatfmt=".2f")) - print(f"Results saved to: {filename}") - else: # output_type == "tflops" - if combined_tf_df is not None: - filename = generate_output_filename(function_configs, "tflops", output_format) - print(f"\nCombined throughput (TFLOPs) table:") - print(combined_tf_df) - - if output_format == "csv": - combined_tf_df.to_csv(filename, index=False) - print(f"Results saved to: {filename}") - else: # markdown - with open(filename, 'w') as f: - f.write(combined_tf_df.to_markdown(index=False, floatfmt=".2f")) - print(f"Results saved to: {filename}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py new file mode 100755 index 00000000000..085232cedc5 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -0,0 +1,4997 @@ +import os +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +from typing import Literal, Optional +from .utils import ( + DEBUG, + DROPOUT_USE_PYTORCH, + DROPOUT_DUMP, + compute_fp8_scaling_factors, + create_dropout_mask, + create_dropout_mask_varlen, + is_cdna, + is_fp8, +) + +# NOTE: triton fails to import tl.constexprs so create them here for the file +tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) +tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) + + +def get_autotune_configs(): + if False: + if is_cdna(): + # shared meta-parameters + NUM_STAGES = 1 + NUM_WARPS = 4 + WAVES_PER_EU = 2 + MATRIX_INSTR_NONKDIM = 16 + + preprocess_autotune_configs = [ + triton.Config( + { + "PRE_BLOCK": 128, + "waves_per_eu": WAVES_PER_EU, + "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, + }, + num_stages=NUM_STAGES, + num_warps=NUM_WARPS, + ), # og config + triton.Config( + { + "PRE_BLOCK": 64, + "waves_per_eu": WAVES_PER_EU, + "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, + }, + num_stages=NUM_STAGES, + num_warps=NUM_WARPS, + ), + triton.Config( + { + "PRE_BLOCK": 32, + "waves_per_eu": WAVES_PER_EU, + "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, + }, + num_stages=NUM_STAGES, + num_warps=NUM_WARPS, + ), + triton.Config( + { + "PRE_BLOCK": 16, + "waves_per_eu": WAVES_PER_EU, + "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, + }, + num_stages=NUM_STAGES, + num_warps=NUM_WARPS, + ), + ] + preprocess_autotune_keys = [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_HEAD_DIM_QK", + "ACTUAL_HEAD_DIM_V", + "IS_VARLEN", + "HQ", + "HK", + ] + causal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": WAVES_PER_EU, + "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, + }, + num_stages=NUM_STAGES, + num_warps=NUM_WARPS, + ), # og config + triton.Config( + { + "BLOCK_M1": 16, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 16, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": WAVES_PER_EU, + "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, + }, + num_stages=NUM_STAGES, + num_warps=NUM_WARPS, + ), + triton.Config( + { + "BLOCK_M1": 16, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 16, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": WAVES_PER_EU, + "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, + }, + num_stages=NUM_STAGES, + num_warps=NUM_WARPS, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": WAVES_PER_EU, + "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, + }, + num_stages=NUM_STAGES, + num_warps=NUM_WARPS, + ), + ] + causal_autotune_keys = [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_HEAD_DIM_QK", + "ACTUAL_HEAD_DIM_V", + "IS_VARLEN", + "HQ", + "HK", + ] + noncausal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": WAVES_PER_EU, + "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, + }, + num_stages=NUM_STAGES, + num_warps=NUM_WARPS, + ), # og config + triton.Config( + { + "BLOCK_M1": 16, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 16, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": WAVES_PER_EU, + "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, + }, + num_stages=NUM_STAGES, + num_warps=NUM_WARPS, + ), + triton.Config( + { + "BLOCK_M1": 16, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 16, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": WAVES_PER_EU, + "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, + }, + num_stages=NUM_STAGES, + num_warps=NUM_WARPS, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": WAVES_PER_EU, + "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, + }, + num_stages=NUM_STAGES, + num_warps=NUM_WARPS, + ), + ] + noncausal_autotune_keys = [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_HEAD_DIM_QK", + "ACTUAL_HEAD_DIM_V", + "IS_VARLEN", + "HQ", + "HK", + ] + + return ( + (preprocess_autotune_configs, preprocess_autotune_keys), + (causal_autotune_configs, causal_autotune_keys), + (noncausal_autotune_configs, noncausal_autotune_keys), + ) + else: + raise ValueError("Unknown Device Type") + else: + # meta-parameters + # TODO: fix num_stages later + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + PRE_BLOCK = 128 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + + assert BLOCK_N1 == BLOCK_M2 + + # configs for the kernels + preprocess_autotune_configs = [ + triton.Config( + {"PRE_BLOCK": PRE_BLOCK, "waves_per_eu": WAVES_PER_EU}, + num_stages=NUM_STAGES, + num_warps=NUM_WARPS, + ), + ] + preprocess_autotune_keys = [ + "max_seqlen_q", + "ACTUAL_HEAD_DIM_V", + "IS_VARLEN", + ] + causal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": BLOCK_M1, + "BLOCK_N1": BLOCK_N1, + "BLOCK_M2": BLOCK_M2, + "BLOCK_N2": BLOCK_N2, + "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, + "waves_per_eu": WAVES_PER_EU, + }, + num_stages=NUM_STAGES, + num_warps=NUM_WARPS, + ), + ] + causal_autotune_keys = [ + "dropout_p", + "max_seqlen_q", + "max_seqlen_k", + "ACTUAL_HEAD_DIM_QK", + "ACTUAL_HEAD_DIM_V", + "IS_VARLEN", + "HQ", + "HK", + ] + noncausal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": BLOCK_M1, + "BLOCK_N1": BLOCK_N1, + "BLOCK_M2": BLOCK_M2, + "BLOCK_N2": BLOCK_N2, + "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, + "waves_per_eu": WAVES_PER_EU, + }, + num_stages=NUM_STAGES, + num_warps=NUM_WARPS, + ), + ] + noncausal_autotune_keys = [ + "dropout_p", + "max_seqlen_q", + "max_seqlen_k", + "ACTUAL_HEAD_DIM_QK", + "ACTUAL_HEAD_DIM_V", + "IS_VARLEN", + "HQ", + "HK", + ] + return ( + (preprocess_autotune_configs, preprocess_autotune_keys), + (causal_autotune_configs, causal_autotune_keys), + (noncausal_autotune_configs, noncausal_autotune_keys), + ) + + +( + (preprocess_autotune_configs, preprocess_autotune_keys), + (causal_autotune_configs, causal_autotune_keys), + (noncausal_autotune_configs, noncausal_autotune_keys), +) = get_autotune_configs() + + +# This function computes delta given output Out and gradient DO +# Here is the I/O shape: +# Out: (batch, nhead_q, max_seqlens_q, headDim) +# DO: (batch, nhead_q, max_seqlens_q, headDim) +# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at +@triton.jit +def _bwd_fused_atomics_preprocess( + o_ptr, + do_ptr, # noqa: E741 + delta_ptr, + stride_o_b, + stride_o_h, + stride_o_m, + stride_o_k, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_descale_do_z, + cu_seqlens_q, + max_seqlen_q, + descale_do_ptr, + BLOCK_M: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, +): + pid_m = tl.program_id(0) # seqlen + bid = tl.program_id(1) # batch + hid = tl.program_id(2) # head + + # Handle varlen + q_start = 0 + seqlen_q = max_seqlen_q + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + seqlen_q = q_end - q_start + else: + q_start = 0 + seqlen_q = max_seqlen_q + + # Compute offsets + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + + # Offset O/DO by batch, head and q_start + offs = ( + bid * stride_o_b + + hid * stride_o_h + + q_start * stride_o_m + + offs_m[:, None] * stride_o_m + + offs_k[None, :] * stride_o_k + ) + + # create masks + mask_m = offs_m < seqlen_q + mask = mask_m[:, None] + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + if PADDED_HEAD: + mask &= offs_k[None, :] < BLOCK_D_MODEL + + # load [BLOCK_M, BLOCK_D_MODEL_POW2] + o = tl.load(o_ptr + offs, mask=mask, other=0.0) + do = tl.load(do_ptr + offs, mask=mask, other=0.0) + + # compute and write-back to delta + if IS_FP8: + descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hid) + + # NOTE: do is in the fp8 range and o is not in fp8 + delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) + else: + delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) + + offs_delta = ( + bid * stride_delta_b + + hid * stride_delta_h + + q_start * stride_delta_m + + offs_m * stride_delta_m + ) + tl.store(delta_ptr + offs_delta, delta, mask=mask_m) + + +@triton.jit +def _bwd_fused_atomics_dq_inner( + dq, + q, + K, + V, + do, + m, + Delta, + sm_scale, + stride_qm, + stride_qk, + stride_kn, + stride_kk, + stride_vn, + stride_vk, + stride_dropout_m, + stride_dropout_n, + stride_deltam, + seqlen_q, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + descale_do, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + RCP_LN2: tl.constexpr = 1.4426950408889634 + + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + + # mask to make sure not OOB of seqlen_q + mask_m = offs_m < seqlen_q + + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk + + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) + + curr_n = start_n + step_n = BLOCK_N + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + for blk_idx in range(num_steps): + offs_n = curr_n + tl.arange(0, BLOCK_N) + # end_n is needed because the end of causal True might not be perfectly + # aligned with the end of the block + mask_n = offs_n < end_n + mask_kT = mask_n[None, :] + mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) + if PADDED_HEAD: + mask_kT &= offs_k[:, None] < BLOCK_D_MODEL + + kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) + + # dropout + if ENABLE_DROPOUT: + philox_offs = ( + curr_philox_offset + + offs_m[:, None] * stride_dropout_m + + offs_n[None, :] * stride_dropout_n + ) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1 / (1 - dropout_p) + + # qk + if IS_FP8: + qk = tl.dot(q, kT) * descale_q * descale_k + else: + qk = tl.dot(q, kT) + p = tl.math.exp2(qk * sm_scale * RCP_LN2 - m * RCP_LN2) + + if MASK: + causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] + mask = causal_mask * mask_mn + p = tl.where(mask, p, 0.0) + + # dp + if IS_FP8: + dp = tl.dot(do, vT) * descale_do * descale_v + else: + dp = tl.dot(do, vT) + + if ENABLE_DROPOUT: + dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + + # ds + delta_i = Di[:, None] + ds = p * (dp - delta_i) + + # dq + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + if IS_FP8: + scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) + dq += ( + tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) + * descale_ds + * descale_k + ) + else: + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) + + curr_n += step_n + kT_ptrs += step_n * stride_kn + vT_ptrs += step_n * stride_vn + return dq + + +@triton.jit +def _bwd_fused_atomics_dkdv_inner( + dk, + dv, + Q, + k, + v, + DO, + M, + D, + sm_scale, + stride_q_m, + stride_q_k, + stride_do_m, + stride_do_k, + stride_dropout_m, + stride_dropout_n, + stride_deltam, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + seqlen_q, + seqlen_k, + start_n, + start_m, + num_steps, + descale_q, + descale_k, + descale_v, + descale_do, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + qT_ptrs = ( + Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k + ) # [BLOCK_D_MODEL_POW2, BLOCK_M] + do_ptrs = DO + offs_m[:, None] * stride_do_m + offs_k[None, :] * stride_do_k + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 + + # Iterate over blocks(BLOCK_M size) of Q while calculating + # a fixed block(BLOCK_N) of dk and dv. Note, during backward + # pass P has to be recomputed. However, this kernel computes + # dV and dK, so we compute we need P^T and S^T. See backward pass + # equations + # + # From Flash Attention Paper: + # ForwardPass: S = QkT, P=softmax(S), O=PV + # + # BackwardPass equations + # dV = P^TdO + # dP = dOV^T + # dS = dsoftmax(dP) + # dQ = dSK + # dK = QdS^T + for blk_idx in range(num_steps): + offs_m = curr_m + tl.arange(0, BLOCK_M) + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + if PADDED_HEAD: + mask_qT &= offs_k[:, None] < BLOCK_D_MODEL + mask_do &= offs_k[None, :] < BLOCK_D_MODEL + + # load qT + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + + # dropout + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = ( + curr_philox_offset + + offs_m[None, :] * stride_dropout_m + + offs_n[:, None] * stride_dropout_n + ) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + + # Load M + m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + + # Compute qkT + if IS_FP8: + qkT = tl.dot(k, qT) * descale_q * descale_k + else: + qkT = tl.dot(k, qT) + + # Compute pT(use m and also apply sm_scale) + pT = tl.math.exp(qkT * sm_scale - m[None, :]) + + if MASK: + causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] + mask = causal_mask & mask_nm + pT = tl.where(mask, pT, 0.0) + + # load DO + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + + # dV + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + if IS_FP8: + scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors( + pT_dropout, FP8_MAX + ) + dv += ( + tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) + * descale_p_dropout + * descale_do + ) + else: + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + if IS_FP8: + scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) + dv += ( + tl.dot((pT * scale_pT).to(do.type.element_ty), do) + * descale_pT + * descale_do + ) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + # Load delta + Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) + + # Compute dP and dS + if IS_FP8: + dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do + else: + dpT = tl.dot(v, tl.trans(do)) + + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + + # compute dk + if IS_FP8: + scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) + dk += ( + tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) + * descale_dsT + * descale_q + ) + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + + # increment pointers + curr_m += step_m + qT_ptrs += step_m * stride_q_m + do_ptrs += step_m * stride_do_m + + return dk, dv + + +@triton.jit +def _bwd_fused_atomics_dkdvdq_inner( + dk, + dv, + Q, + k, + v, + DO, + DQ, + M, + D, + sm_scale, + stride_q_m, + stride_q_k, + stride_dq_m, + stride_dq_k, + stride_do_m, + stride_do_k, + stride_dropout_m, + stride_dropout_n, + stride_deltam, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + seqlen_q, + seqlen_k, + start_n, + start_m, + num_steps, + descale_q, + descale_k, + descale_v, + descale_do, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + workgroup_id: tl.int32, +): + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + + qT_ptrs_start = ( + Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k + ) # [BLOCK_D_MODEL_POW2, BLOCK_M] + dq_ptrs_start = ( + DQ + offs_m[:, None] * stride_dq_m + offs_k[None, :] * stride_dq_k + ) # [BLOCK_M, BLOCK_D_MODEL_POW2] + + do_ptrs_start = DO + offs_m[:, None] * stride_do_m + offs_k[None, :] * stride_do_k + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 + + # Iterate over blocks(BLOCK_M size) of Q while calculating + # a fixed block(BLOCK_N) of dk and dv. Note, during backward + # pass P has to be recomputed. However, this kernel computes + # dV and dK, so we compute we need P^T and S^T. See backward pass + # equations + # + # From Flash Attention Paper: + # ForwardPass: S = QkT, P=softmax(S), O=PV + # + # BackwardPass equations + # dV = P^TdO + # dP = dOV^T + # dS = dsoftmax(dP) + # dQ = dSK + # dK = QdS^T + + # Compute a starting index and step based on workgroup_id + # Use a simple hash-like function to spread out the starting points + start_idx = ( + workgroup_id * 17 + ) % num_steps # 17 is an arbitrary prime to spread indices + # Ensure step is coprime with num_steps to visit all indices exactly once + step = 1 # 3 if num_steps > 1 or num_steps==3 else 1 # coprime with num_steps + + for iter in range(num_steps): + # Compute the permuted block index + blk_idx = (start_idx + iter * step) % num_steps + + curr_m = start_m + blk_idx * step_m + qT_ptrs = qT_ptrs_start + blk_idx * step_m * stride_q_m + dq_ptrs = dq_ptrs_start + blk_idx * step_m * stride_dq_m + do_ptrs = do_ptrs_start + blk_idx * step_m * stride_do_m + + offs_m = curr_m + tl.arange(0, BLOCK_M) + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + + if PADDED_HEAD: + mask_qT &= offs_k[:, None] < BLOCK_D_MODEL + mask_do &= offs_k[None, :] < BLOCK_D_MODEL + + # load qT + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + + # dropout + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = ( + curr_philox_offset + + offs_m[None, :] * stride_dropout_m + + offs_n[:, None] * stride_dropout_n + ) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + + # Load M + m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + + # Compute qkT + if IS_FP8: + qkT = tl.dot(k, qT) * descale_q * descale_k + else: + qkT = tl.dot(k, qT) + + # Compute pT(use m and also apply sm_scale) + pT = tl.math.exp(qkT * sm_scale - m[None, :]) + + if MASK: + causal_mask = (offs_m[None, :] - delta_qk) >= (offs_n[:, None]) + mask = causal_mask & mask_nm + pT = tl.where(mask, pT, 0.0) + + # load DO + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + + # dV + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + if IS_FP8: + scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors( + pT_dropout, FP8_MAX + ) + dv += ( + tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) + * descale_p_dropout + * descale_do + ) + else: + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + if IS_FP8: + scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) + dv += ( + tl.dot((pT * scale_pT).to(do.type.element_ty), do) + * descale_pT + * descale_do + ) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + # Load delta + Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) + + # Compute dP and dS + if IS_FP8: + dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do + else: + dpT = tl.dot(v, tl.trans(do)) + + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + + # compute dk + if IS_FP8: + scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) + dk += ( + tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) + * descale_dsT + * descale_q + ) + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + + # We can compute the dq_partial here and do a atomic add to the correct memory location + # NOTE: Possible problems with the atomic add: contention, is inside a loop which has achieved bad perf before + # (BLOCK_M, BLOCK_N) x (BLOCK_N, D) + if IS_FP8: + dq_partial = ( + tl.dot((dsT * scale_dsT).to(k.dtype).T, k) * descale_dsT * descale_k + ) + else: + dq_partial = tl.dot(dsT.to(k.dtype).T, k) + tl.atomic_add( + dq_ptrs, + dq_partial * sm_scale, + mask=mask_m[:, None], + sem="relaxed", + ) + + return dk, dv + + +@triton.jit +def _bwd_kernel_fused_atomics_dkdvdq_causal( + q_ptr, + k_ptr, + v_ptr, + sm_scale, + do_ptr, + dk_ptr, + dv_ptr, + dq_ptr, + m_ptr, + delta_ptr, + stride_q_b, + stride_q_h, + stride_q_m, + stride_q_k, + stride_k_b, + stride_k_h, + stride_k_n, + stride_k_k, + stride_v_b, + stride_v_h, + stride_v_n, + stride_v_k, + stride_dk_b, + stride_dk_h, + stride_dk_n, + stride_dk_k, + stride_dq_b, + stride_dq_h, + stride_dq_m, + stride_dq_k, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_do_b, + stride_do_h, + stride_do_m, + stride_do_k, + stride_dropout_b, + stride_dropout_h, + stride_dropout_m, + stride_dropout_n, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_descale_do_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset_base, + descale_q_ptr, + descale_k_ptr, + descale_v_ptr, + descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BATCH, + NUM_K_PIDS, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + wid = tl.program_id(0) # workgoup id: 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1 + + # workgroups get launched first along batch dim, then in head_k dim, and then in seq k block dim + batch_idx = wid % BATCH + head_k_idx = wid // BATCH % NUM_K_HEADS + seq_k_blk_idx = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS + + # Determine q and k start along with seqlen_q and seqlen_k + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + batch_idx) + q_end = tl.load(cu_seqlens_q + batch_idx + 1) + k_start = tl.load(cu_seqlens_k + batch_idx) + k_end = tl.load(cu_seqlens_k + batch_idx + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + # Figure out causal starting block since we have seqlen_q >=< seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. + delta_qk = seqlen_q - seqlen_k + + # q > k: diretcly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + + # q < k: some blocks will have no Masked block, other needs to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N + delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M + if delta_qk >= 0: + start_delta = delta_qk + else: + start_delta = start_delta_q_lt_k + + start_n = seq_k_blk_idx * BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + if PADDED_HEAD: + mask_k = offs_k < BLOCK_D_MODEL + mask_kv &= mask_k[None, :] + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = ( + batch_idx * stride_k_b + + head_k_idx * stride_k_h + + k_start * stride_k_n + + offs_n[:, None] * stride_k_n + + offs_k[None, :] * stride_k_k + ) + adj_v = ( + batch_idx * stride_v_b + + head_k_idx * stride_v_h + + k_start * stride_v_n + + offs_n[:, None] * stride_v_n + + offs_k[None, :] * stride_v_k + ) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(k_ptr + adj_k, mask=mask_kv, other=0.0) + v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0) + + # If MQA / GQA, set the K and V head offsets appropriately. + for head_q_idx in range( + head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE + ): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N + else: + start_m = max(start_n + delta_qk, 0) + start_m = (start_m // BLOCK_M) * BLOCK_M + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N + residue_m + + # offset input and output tensor by batch and Q/K heads + adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m + adj_dq = ( + batch_idx * stride_dq_b + head_q_idx * stride_dq_h + q_start * stride_dq_m + ) + + q_ptr_adj = q_ptr + adj_q + dq_ptr_adj = dq_ptr + adj_dq + + adj_do = ( + batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m + ) + do_ptr_adj = do_ptr + adj_do + adj_delta = ( + batch_idx * stride_delta_b + + head_q_idx * stride_delta_h + + q_start * stride_delta_m + ) + m_ptr_adj = m_ptr + adj_delta + delta_ptr_adj = delta_ptr + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h + ) + dropout_offset = ( + dropout_mask + + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h + ) + + MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M) + + # when q < k, we may skip the initial masked op + # if seq_k_blk_idx < num_blocks_skip: + # num_steps = 0 + + if IS_FP8: + descale_q = tl.load( + descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx + ) + descale_k = tl.load( + descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx + ) + descale_v = tl.load( + descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx + ) + descale_do = tl.load( + descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx + ) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + # if unaligned start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + dk, dv = _bwd_fused_atomics_dkdvdq_inner( + dk, + dv, # output tensors + q_ptr_adj, + k, + v, + do_ptr_adj, + dq_ptr_adj, + m_ptr_adj, + delta_ptr_adj, + sm_scale, # input tensors + stride_q_m, + stride_q_k, # strides for q + stride_dq_m, + stride_dq_k, # strides for q + stride_do_m, + stride_do_k, # strides for o + stride_dropout_m, + stride_dropout_n, # strides for dropout + stride_delta_m, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, # + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + descale_do, # fp8 descale factors from user + MASK_BLOCK_M, + BLOCK_N, # block dim + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, # head dim + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + workgroup_id=seq_k_blk_idx, + ) + + start_m += num_steps * MASK_BLOCK_M + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) + end_m = start_m + num_steps * BLOCK_M + + dk, dv = _bwd_fused_atomics_dkdvdq_inner( + dk, + dv, # output tensors + q_ptr_adj, + k, + v, + do_ptr_adj, + dq_ptr_adj, + m_ptr_adj, + delta_ptr_adj, + sm_scale, # input tensors + stride_q_m, + stride_q_k, # strides for q + stride_dq_m, + stride_dq_k, # strides for dq + stride_do_m, + stride_do_k, # strides for o + stride_dropout_m, + stride_dropout_n, # strides for dropout + stride_delta_m, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, # + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + descale_do, # fp8 descale factors from user + BLOCK_M, + BLOCK_N, # block dim + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, # head dim + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + workgroup_id=seq_k_blk_idx, + ) + + # Write back dV and dK. + offs_dkdv = ( + batch_idx * stride_dk_b + + head_k_idx * stride_dk_h + + k_start * stride_dk_n + + offs_n[:, None] * stride_dk_n + + offs_k[None, :] * stride_dk_k + ) + tl.store(dv_ptr + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(dk_ptr + offs_dkdv, dk, mask=mask_kv) + + +@triton.jit +def _bwd_kernel_fused_atomics_dkdv_causal( + q_ptr, + k_ptr, + v_ptr, + sm_scale, + do_ptr, + dk_ptr, + dv_ptr, + m_ptr, + delta_ptr, + stride_q_b, + stride_q_h, + stride_q_m, + stride_q_k, + stride_k_b, + stride_k_h, + stride_k_n, + stride_k_k, + stride_v_b, + stride_v_h, + stride_v_n, + stride_v_k, + stride_dk_b, + stride_dk_h, + stride_dk_n, + stride_dk_k, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_do_b, + stride_do_h, + stride_do_m, + stride_do_k, + stride_dropout_b, + stride_dropout_h, + stride_dropout_m, + stride_dropout_n, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_descale_do_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset_base, + descale_q_ptr, + descale_k_ptr, + descale_v_ptr, + descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + # seq block, batch, head_k + seq_k_blk_idx = tl.program_id(0) + batch_idx = tl.program_id(1) + head_k_idx = tl.program_id(2) + + # Determine q and k start along with seqlen_q and seqlen_k + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + batch_idx) + q_end = tl.load(cu_seqlens_q + batch_idx + 1) + k_start = tl.load(cu_seqlens_k + batch_idx) + k_end = tl.load(cu_seqlens_k + batch_idx + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + # Figure out causal starting block since we have seqlen_q >=< seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. + delta_qk = seqlen_q - seqlen_k + + # q > k: diretcly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + + # q < k: some blocks will have no Masked block, other needs to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N + delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M + if delta_qk >= 0: + start_delta = delta_qk + else: + start_delta = start_delta_q_lt_k + + start_n = seq_k_blk_idx * BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + if PADDED_HEAD: + mask_k = offs_k < BLOCK_D_MODEL + mask_kv &= mask_k[None, :] + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = ( + batch_idx * stride_k_b + + head_k_idx * stride_k_h + + k_start * stride_k_n + + offs_n[:, None] * stride_k_n + + offs_k[None, :] * stride_k_k + ) + adj_v = ( + batch_idx * stride_v_b + + head_k_idx * stride_v_h + + k_start * stride_v_n + + offs_n[:, None] * stride_v_n + + offs_k[None, :] * stride_v_k + ) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(k_ptr + adj_k, mask=mask_kv, other=0.0) + v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0) + + # If MQA / GQA, set the K and V head offsets appropriately. + for head_q_idx in range( + head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE + ): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N + else: + start_m = max(start_n + delta_qk, 0) + start_m = start_m // BLOCK_M * BLOCK_M + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N + residue_m + + # offset input and output tensor by batch and Q/K heads + adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m + q_ptr_adj = q_ptr + adj_q + adj_do = ( + batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m + ) + do_ptr_adj = do_ptr + adj_do + adj_delta = ( + batch_idx * stride_delta_b + + head_q_idx * stride_delta_h + + q_start * stride_delta_m + ) + m_ptr_adj = m_ptr + adj_delta + delta_ptr_adj = delta_ptr + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h + ) + dropout_offset = ( + dropout_mask + + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h + ) + + MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M) + # when q < k, we may skip the initial masked op + if seq_k_blk_idx < num_blocks_skip: + num_steps = 0 + + if IS_FP8: + descale_q = tl.load( + descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx + ) + descale_k = tl.load( + descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx + ) + descale_v = tl.load( + descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx + ) + descale_do = tl.load( + descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx + ) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + # if start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + dk, dv = _bwd_fused_atomics_dkdv_inner( + dk, + dv, # output tensors + q_ptr_adj, + k, + v, + do_ptr_adj, + m_ptr_adj, + delta_ptr_adj, + sm_scale, # input tensors + stride_q_m, + stride_q_k, # strides for q + stride_do_m, + stride_do_k, # strides for o + stride_dropout_m, + stride_dropout_n, # strides for dropout + stride_delta_m, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, # + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + descale_do, # fp8 descale factors from user + MASK_BLOCK_M, + BLOCK_N, # block dim + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, # head dim + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + start_m += num_steps * MASK_BLOCK_M + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) + end_m = start_m + num_steps * BLOCK_M + + dk, dv = _bwd_fused_atomics_dkdv_inner( + dk, + dv, # output tensors + q_ptr_adj, + k, + v, + do_ptr_adj, + m_ptr_adj, + delta_ptr_adj, + sm_scale, # input tensors + stride_q_m, + stride_q_k, # strides for q + stride_do_m, + stride_do_k, # strides for o + stride_dropout_m, + stride_dropout_n, # strides for dropout + stride_delta_m, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, # + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + descale_do, # fp8 descale factors from user + BLOCK_M, + BLOCK_N, # block dim + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, # head dim + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + + # Write back dV and dK. + offs_dkdv = ( + batch_idx * stride_dk_b + + head_k_idx * stride_dk_h + + k_start * stride_dk_n + + offs_n[:, None] * stride_dk_n + + offs_k[None, :] * stride_dk_k + ) + tl.store(dv_ptr + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(dk_ptr + offs_dkdv, dk, mask=mask_kv) + + +@triton.jit +def _bwd_kernel_fused_atomics_dq_causal( + q_ptr, + k_ptr, + v_ptr, + sm_scale, + do_ptr, + dq_ptr, + m_ptr, + delta_ptr, + stride_q_b, + stride_q_h, + stride_q_m, + stride_q_k, + stride_k_b, + stride_k_h, + stride_k_n, + stride_k_k, + stride_v_b, + stride_v_h, + stride_v_n, + stride_v_k, + stride_dq_b, + stride_dq_h, + stride_dq_m, + stride_dq_k, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_do_b, + stride_do_h, + stride_do_m, + stride_do_k, + stride_dropout_b, + stride_dropout_h, + stride_dropout_m, + stride_dropout_n, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_descale_do_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset_base, + descale_q_ptr, + descale_k_ptr, + descale_v_ptr, + descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + seq_q_blk_idx = tl.program_id(0) + batch_idx = tl.program_id(1) + head_k_idx = tl.program_id(2) + + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + batch_idx) + q_end = tl.load(cu_seqlens_q + batch_idx + 1) + k_start = tl.load(cu_seqlens_k + batch_idx) + k_end = tl.load(cu_seqlens_k + batch_idx + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + # Figure out causal starting block since we have seqlen_q <=> seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. + # DQ tiles on M dim and iterate on N dim, so we there could be some tiles we + # can simply skip and we need to adjust starting position. + start_m = seq_q_blk_idx * BLOCK_M + # seqlen_q > seqlen_k, no need to process these tile for dq + delta_qk = seqlen_q - seqlen_k + if start_m + BLOCK_M < delta_qk: + return + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_m = start_m + tl.arange(0, BLOCK_M) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + if PADDED_HEAD: + mask_k = offs_k < BLOCK_D_MODEL + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_q_m + offs_k[None, :] * stride_q_k + offs_do = offs_m[:, None] * stride_do_m + offs_k[None, :] * stride_do_k + adj_k = batch_idx * stride_k_b + head_k_idx * stride_k_h + k_start * stride_k_n + adj_v = batch_idx * stride_v_b + head_k_idx * stride_v_h + k_start * stride_v_n + k_ptr_adj = k_ptr + v_ptr_adj = v_ptr + k_ptr_adj += adj_k + v_ptr_adj += adj_v + + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + for head_q_idx in range( + head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE + ): + # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front + # for every M-tile + end_n = start_m + BLOCK_M - delta_qk + # clamp end_n at [0, seqlen_k] + end_n = max(min(end_n, seqlen_k), 0) + + # offset input and output tensor by batch and Q/K heads + adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m + adj_do = ( + batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m + ) + adj_delta = ( + batch_idx * stride_delta_b + + head_q_idx * stride_delta_h + + q_start * stride_delta_m + ) + delta_ptr_adj = delta_ptr + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h + ) + dropout_offset = ( + dropout_mask + + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h + ) + + q = tl.load(q_ptr + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(do_ptr + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(m_ptr + adj_delta + offs_m * stride_delta_m, mask=offs_m < seqlen_q) + m = m[:, None] + + MASK_BLOCK_N: tl.constexpr = BLOCK_N // BLK_SLICE_FACTOR + # start can only be 0 at minimum + start_n = max(end_n - BLOCK_M, 0) + num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N) + + if IS_FP8: + descale_q = tl.load( + descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx + ) + descale_k = tl.load( + descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx + ) + descale_v = tl.load( + descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx + ) + descale_do = tl.load( + descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx + ) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32) + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _bwd_dq_inner, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + dq = _bwd_fused_atomics_dq_inner( + dq, + q, + k_ptr_adj, + v_ptr_adj, + do, + m, + delta_ptr_adj, + sm_scale, + stride_q_m, + stride_q_k, + stride_k_n, + stride_k_k, + stride_v_n, + stride_v_k, + stride_dropout_m, + stride_dropout_n, + stride_delta_m, + seqlen_q, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + descale_do, + BLOCK_M, + MASK_BLOCK_N, + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, + MASK=True, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + end_n -= num_steps * MASK_BLOCK_N + num_steps = tl.cdiv(end_n, BLOCK_N) + start_n = max(end_n - num_steps * BLOCK_N, 0) + dq = _bwd_fused_atomics_dq_inner( + dq, + q, + k_ptr_adj, + v_ptr_adj, + do, + m, + delta_ptr_adj, + sm_scale, + stride_q_m, + stride_q_k, + stride_k_n, + stride_k_k, + stride_v_n, + stride_v_k, + stride_dropout_m, + stride_dropout_n, + stride_delta_m, + seqlen_q, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + descale_do, + BLOCK_M, + BLOCK_N, + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + # Write back dQ. + offs_dq = ( + batch_idx * stride_dq_b + + head_q_idx * stride_dq_h + + q_start * stride_dq_m + + offs_m[:, None] * stride_dq_m + + offs_k[None, :] * stride_dq_k + ) + dq *= sm_scale + tl.store(dq_ptr + offs_dq, dq, mask=mask_q) + + +@triton.jit +def _bwd_kernel_fused_atomics_dkdvdq_noncausal( + Q, + K, + V, + sm_scale, + DO, + DK, + DV, + DQ, + M, + Delta, + stride_qb, + stride_qh, + stride_qm, + stride_qk, + stride_kb, + stride_kh, + stride_kn, + stride_kk, + stride_vb, + stride_vh, + stride_vn, + stride_vk, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkk, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqk, + stride_deltab, + stride_deltah, + stride_deltam, + stride_dob, + stride_doh, + stride_dom, + stride_dok, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_descale_do_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q_ptr, + descale_k_ptr, + descale_v_ptr, + descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BATCH, + NUM_K_PIDS, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + # workgroup id + wid = tl.program_id(0) # 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1 + + # Workgroups get launched first along batch dim, then in head_k dim, and then in seq k block dim + # This is in order to avoid contention for the tl.atomic_add (inside _bwd_dkdvdq_inner) that happens between workgroups that share the same batch and head_k. + bid = wid % BATCH + hkid = wid // BATCH % NUM_K_HEADS + pid = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS + + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + start_n = pid * BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + if PADDED_HEAD: + mask_kv &= offs_k < BLOCK_D_MODEL + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_k[None, :] * stride_kk + ) + adj_v = ( + bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_k[None, :] * stride_vk + ) + + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + + Q_ptr = Q + adj_q + DQ_ptr = DQ + adj_dq + + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + # dropout + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + if IS_FP8: + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) + descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) + descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) + descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M) + + dk, dv = _bwd_fused_atomics_dkdvdq_inner( + dk, + dv, + Q_ptr, + k, + v, + DO_ptr, + DQ_ptr, + M_ptr, + Delta_ptr, + sm_scale, + stride_qm, + stride_qk, + stride_dqm, + stride_dqk, + stride_dom, + stride_dok, + stride_dropoutm, + stride_dropoutn, + stride_deltam, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + seqlen_q, + seqlen_k, + start_n, + start_m, + num_steps, + descale_q, + descale_k, + descale_v, + descale_do, + BLOCK_M, + BLOCK_N, + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + workgroup_id=pid, + ) + + adj_dkdv = ( + bid * stride_dkb + + hkid * stride_dkh + + k_start * stride_dkn + + offs_n[:, None] * stride_dkn + + offs_k[None, :] * stride_dkk + ) + tl.store(DV + adj_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv, dk, mask=mask_kv) + + +@triton.jit +def _bwd_kernel_fused_atomics_dkdv_noncausal( + Q, + K, + V, + sm_scale, + DO, + DK, + DV, + M, + Delta, + stride_qb, + stride_qh, + stride_qm, + stride_qk, + stride_kb, + stride_kh, + stride_kn, + stride_kk, + stride_vb, + stride_vh, + stride_vn, + stride_vk, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkk, + stride_deltab, + stride_deltah, + stride_deltam, + stride_dob, + stride_doh, + stride_dom, + stride_dok, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_descale_do_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q_ptr, + descale_k_ptr, + descale_v_ptr, + descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + start_n = pid * BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + if PADDED_HEAD: + mask_kv &= offs_k < BLOCK_D_MODEL + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_k[None, :] * stride_kk + ) + adj_v = ( + bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_k[None, :] * stride_vk + ) + + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + # dropout + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + if IS_FP8: + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) + descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) + descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) + descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M) + dk, dv = _bwd_fused_atomics_dkdv_inner( + dk, + dv, + Q_ptr, + k, + v, + DO_ptr, + M_ptr, + Delta_ptr, + sm_scale, + stride_qm, + stride_qk, + stride_dom, + stride_dok, + stride_dropoutm, + stride_dropoutn, + stride_deltam, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + seqlen_q, + seqlen_k, + start_n, + start_m, + num_steps, + descale_q, + descale_k, + descale_v, + descale_do, + BLOCK_M, + BLOCK_N, + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + + adj_dkdv = ( + bid * stride_dkb + + hkid * stride_dkh + + k_start * stride_dkn + + offs_n[:, None] * stride_dkn + + offs_k[None, :] * stride_dkk + ) + tl.store(DV + adj_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv, dk, mask=mask_kv) + + +@triton.jit +def _bwd_kernel_fused_atomics_dq_noncausal( + Q, + K, + V, + sm_scale, + DO, + DQ, + M, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qk, + stride_kb, + stride_kh, + stride_kn, + stride_kk, + stride_vb, + stride_vh, + stride_vn, + stride_vk, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqk, + stride_deltab, + stride_deltah, + stride_deltam, + stride_dob, + stride_doh, + stride_dom, + stride_dok, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_descale_do_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset_base, + descale_q_ptr, + descale_k_ptr, + descale_v_ptr, + descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + pid = tl.program_id(0) # seqlen + bid = tl.program_id(1) # batch + hkid = tl.program_id(2) # head_k + + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + start_m = pid * BLOCK_M + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_m = start_m + tl.arange(0, BLOCK_M) + + # mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + if PADDED_HEAD: + mask_k = offs_k < BLOCK_D_MODEL + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + K += adj_k + V += adj_v + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + delta_ptr = delta + adj_delta + + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q) + m = m[:, None] + + # FP8 + if IS_FP8: + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) + descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) + descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) + descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + start_n = 0 + end_n = seqlen_k + num_steps = tl.cdiv(seqlen_k, BLOCK_N) + dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dq = _bwd_fused_atomics_dq_inner( + dq, + q, + K, + V, + do, + m, + delta_ptr, + sm_scale, + stride_qm, + stride_qk, + stride_kn, + stride_kk, + stride_vn, + stride_vk, + stride_dropoutm, + stride_dropoutn, + stride_deltam, + seqlen_q, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + descale_do, + BLOCK_M, + BLOCK_N, + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + + +# This function computes delta given output Out and gradient DO +# Here is the I/O shape: +# Out: (batch, nhead_q, max_seqlens_q, headDim) +# DO: (batch, nhead_q, max_seqlens_q, headDim) +# Delta: (batch, nheads_q, max_seqlens_q) +@triton.autotune( + configs=preprocess_autotune_configs, + key=preprocess_autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def _bwd_preprocess( + O, + DO, # noqa: E741 + Delta, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_descale_do_z, + cu_seqlens_q, + max_seqlen_q, + Descale_do, + PRE_BLOCK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, +): + pid_m = tl.program_id(0) + bid = tl.program_id(1) + hid = tl.program_id(2) + # Handle varlen + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + seqlen_q = q_end - q_start + else: + q_start = 0 + seqlen_q = max_seqlen_q + + # Compute offsets + offs_m = pid_m * PRE_BLOCK + tl.arange(0, PRE_BLOCK) + offs_d = tl.arange(0, HEAD_DIM_V) + # pointer offsets for O & DO + off_o = ( + bid * stride_ob + + hid * stride_oh + + q_start * stride_om + + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_od + ) # noqa: E741 + off_do = ( + bid * stride_dob + + hid * stride_doh + + q_start * stride_dom + + offs_m[:, None] * stride_dom + + offs_d[None, :] * stride_dod + ) + + # create masks + mask_m = offs_m < seqlen_q + mask_md = mask_m[:, None] + PADDED_HEAD_V: tl.constexpr = ACTUAL_HEAD_DIM_V != HEAD_DIM_V + if PADDED_HEAD_V: + mask_md &= offs_d[None, :] < ACTUAL_HEAD_DIM_V + # load + o = tl.load(O + off_o, mask=mask_md, other=0.0) + do = tl.load(DO + off_do, mask=mask_md, other=0.0) + # compute and write-back to delta + if IS_FP8: + off_descale_do = bid * stride_descale_do_z + hid + descale_do = tl.load(Descale_do + off_descale_do) + + # NOTE: do is in the fp8 range and o is not in fp8 + delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) + else: + delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) + off_delta = ( + bid * stride_delta_b + + hid * stride_delta_h + + q_start * stride_delta_m + + offs_m * stride_delta_m + ) + tl.store(Delta + off_delta, delta, mask=mask_m) + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _bwd_dkdv_inner( + dk, + dv, # output + Q, + k, + v, + DO, + M, + D, + sm_scale, # input tensor + stride_qm, + stride_qk, + stride_dom, + stride_dok, + stride_dropoutm, + stride_dropoutn, + stride_lse_m, + stride_delta_m, + BLOCK_M: tl.constexpr, # 16 + BLOCK_N: tl.constexpr, # 128 + HEAD_DIM_QK: tl.constexpr, # + HEAD_DIM_V: tl.constexpr, # + ACTUAL_HEAD_DIM_QK: tl.constexpr, # + ACTUAL_HEAD_DIM_V: tl.constexpr, # + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + seqlen_q, + seqlen_k, # max sequence length for q and k + # Filled in by the wrapper. + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + descale_do, # fp8 descale factors from user + MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal + ENABLE_DROPOUT: tl.constexpr, # activate dropout + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, # activate exp2 + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD_QK: tl.constexpr = ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK + PADDED_HEAD_V: tl.constexpr = ACTUAL_HEAD_DIM_V != HEAD_DIM_V + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) + offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) + offs_k_qk = tl.arange(0, HEAD_DIM_QK) + offs_k_v = tl.arange(0, HEAD_DIM_V) + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + # Q and DO are (seqlen_q, head_dim) + # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM_QK, 1), transpose of q + qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k_qk[:, None] * stride_qk + # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM_V), NOT transposed + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k_v[None, :] * stride_dok + # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N % BLOCK_M == 0) + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + + for blk_idx in range(num_steps): + if DEBUG_TRITON: + print(f"iter {blk_idx}: curr_m = {curr_m}") # noqa: E701 + offs_m = curr_m + tl.arange(0, BLOCK_M) + # update the mask because offs_m advanced + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + if PADDED_HEAD_QK: + mask_qT &= offs_k_qk[:, None] < ACTUAL_HEAD_DIM_QK + if PADDED_HEAD_V: + mask_do &= offs_k_v[None, :] < ACTUAL_HEAD_DIM_V + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + # generate dropout mask + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = ( + curr_philox_offset + + offs_m[None, :] * stride_dropoutm + + offs_n[:, None] * stride_dropoutn + ) + if tl_DROPOUT_USE_PYTORCH: + dropout_offs = ( + offs_m[None, :] * stride_dropoutm + + offs_n[:, None] * stride_dropoutn + ) + dropout_mask = tl.load(curr_dropout_offset + dropout_offs, mask=mask_nm) + else: + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + # Load m before computing qk to reduce pipeline stall. + m = tl.load(M + offs_m * stride_lse_m, mask=mask_m, other=0.0) + if IS_FP8: + qkT = tl.dot(k, qT) * descale_q * descale_k + else: + qkT = tl.dot(k, qT) + qkT_scaled = qkT * sm_scale + + if USE_ALIBI: + relative_pos_block = offs_n[:, None] + seqlen_q - seqlen_k - offs_m[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + qkT_scaled += alibi_block + + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"qT: {qT.shape}\n", qT) + print(f"k: {k.shape}\n", k) + print(f"qkT scaled: {qkT.shape}\n", qkT_scaled) + # TODO: remove the scaling of m later when we removed re-scaling in fwd + if USE_EXP2: + pT = tl.math.exp2(qkT_scaled * RCP_LN2 - m[None, :] * RCP_LN2) + else: + pT = tl.math.exp(qkT_scaled - m[None, :]) + + # Autoregressive masking. + if MASK: + # offset offs_m with delta_qk since the causal mask starts at + # bottom right of the (seqlen_q, seqlen_k) matrix + causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] + mask = causal_mask & mask_nm + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"causal_mask: {causal_mask.shape}\n", causal_mask) + print( + f"qkT after causal: {qkT.shape}\n", + tl.where(causal_mask, qkT * sm_scale, 0.0), + ) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + # Compute dV. + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + if IS_FP8: + scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors( + pT_dropout, FP8_MAX + ) + dv += ( + tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) + * descale_p_dropout + * descale_do + ) + else: + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + if IS_FP8: + scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) + dv += ( + tl.dot((pT * scale_pT).to(do.type.element_ty), do) + * descale_pT + * descale_do + ) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"pT: {pT.shape}\n", pT) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m * stride_delta_m, mask=mask_m) + # Compute dP and dS. + if IS_FP8: + dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do + else: + dpT = tl.dot(v, tl.trans(do)) + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + if IS_FP8: + scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) + dk += ( + tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) + * descale_dsT + * descale_q + ) + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_qm + do_ptrs += step_m * stride_dom + return dk, dv + + +# the main inner-loop logic for computing dQ +@triton.jit +def _bwd_dq_inner( + dq, # output + q, + K, + V, + do, + m, + Delta, + sm_scale, # input + # shared by Q/K/V. + stride_qm, + stride_qk, + stride_kn, + stride_kk, + stride_vn, + stride_vk, + stride_dropoutm, + stride_dropoutn, # stride for dropout + stride_lse_m, + stride_delta_m, + seqlen_q, + seqlen_k, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, # + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + # Filled in by the wrapper. + start_m, + start_n, + end_n, + num_steps, # + descale_q, + descale_k, + descale_v, + descale_do, # fp8 descale factors from user + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD_QK: tl.constexpr = ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK + PADDED_HEAD_V: tl.constexpr = ACTUAL_HEAD_DIM_V != HEAD_DIM_V + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k_qk = tl.arange(0, HEAD_DIM_QK) + offs_k_v = tl.arange(0, HEAD_DIM_V) + + # mask to make sure not OOB of seqlen_q + mask_m = offs_m < seqlen_q + + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k_qk[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k_v[:, None] * stride_vk + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(Delta + offs_m * stride_delta_m, mask=mask_m, other=0.0) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + for blk_idx in range(num_steps): + if DEBUG_TRITON: + print(f"iter {blk_idx}: curr_n = {curr_n}") # noqa: E701 + offs_n = curr_n + tl.arange(0, BLOCK_N2) + # end_n is needed because the end of causal True might not be perfectly + # aligned with the end of the block + mask_n = offs_n < end_n + if DEBUG_TRITON_DETAIL: + print( + f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}" + ) # noqa: E701 + if DEBUG_TRITON_DETAIL: + print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 + mask_kT = mask_n[None, :] + mask_vT = mask_n[None, :] + mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) + if PADDED_HEAD_QK: + mask_kT &= offs_k_qk[:, None] < ACTUAL_HEAD_DIM_QK + if PADDED_HEAD_V: + mask_vT &= offs_k_v[:, None] < ACTUAL_HEAD_DIM_V + + kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_vT, other=0.0) + + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = ( + curr_philox_offset + + offs_m[:, None] * stride_dropoutm + + offs_n[None, :] * stride_dropoutn + ) + if tl_DROPOUT_USE_PYTORCH: + dropout_offs = ( + offs_m[:, None] * stride_dropoutm + + offs_n[None, :] * stride_dropoutn + ) + dropout_mask = tl.load(curr_dropout_offset + dropout_offs, mask=mask_mn) + else: + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1 / (1 - dropout_p) + + if IS_FP8: + qk = tl.dot(q, kT) * descale_q * descale_k + else: + qk = tl.dot(q, kT) + qk_scaled = qk * sm_scale + + if USE_ALIBI: + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + qk_scaled += alibi_block + + if DEBUG_TRITON_DETAIL: + print(f"qk scaled: {qk.shape}\n", qk_scaled) # noqa: E701 + if USE_EXP2: + p = tl.math.exp2(qk_scaled * RCP_LN2 - m * RCP_LN2) + else: + p = tl.math.exp(qk_scaled - m) + + # Autoregressive masking. + if MASK: + causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] + mask = causal_mask & mask_mn + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + if IS_FP8: + dp = tl.dot(do, vT) * descale_do * descale_v + else: + dp = tl.dot(do, vT) + if ENABLE_DROPOUT: + dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + delta_i = Di[:, None] + ds = p * (dp - delta_i) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + if IS_FP8: + scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) + dq += ( + tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) + * descale_ds + * descale_k + ) + else: + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) + # Increment pointers. + curr_n += step_n + kT_ptrs += step_n * stride_kn + vT_ptrs += step_n * stride_vn + return dq + + +@triton.autotune( + configs=causal_autotune_configs, + key=causal_autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), batch) + Q, + K, + V, + sm_scale, + DO, + DQ, + DK, + DV, + M, + Delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_descale_do_z, + stride_az, + stride_ah, + HQ, + HK, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Add seqused parameters + max_seqlen_q, + max_seqlen_k, + Dropout_mask, + dropout_p, + philox_seed, + philox_offset_base, + Alibi_slopes, + Descale_q, + Descale_k, + Descale_v, + Descale_do, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, + USE_SEQUSED: tl.constexpr, # Add flag for seqused + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + hkid = tl.program_id(0) + pid = tl.program_id(1) + bid = tl.program_id(2) + if DEBUG_TRITON: + print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = ( + tl.load(seqused_q + bid) if seqused_q is not None else q_end - q_start + ) + seqlen_q = tl.minimum(actual_seqlen_q, q_end - q_start) + actual_seqlen_k = ( + tl.load(seqused_k + bid) if seqused_k is not None else k_end - k_start + ) + seqlen_k = tl.minimum(actual_seqlen_k, k_end - k_start) + else: + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + delta_qk = seqlen_q - seqlen_k + if DEBUG_TRITON: + print(f"delta_qk = {delta_qk}") # noqa: E701 + PADDED_HEAD_QK: tl.constexpr = ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK + PADDED_HEAD_V: tl.constexpr = ACTUAL_HEAD_DIM_V != HEAD_DIM_V + offs_d_qk = tl.arange(0, HEAD_DIM_QK) + offs_d_v = tl.arange(0, HEAD_DIM_V) + GROUP_SIZE: tl.constexpr = HQ // HK + + # align the delta_qk + start_n = pid * BLOCK_N1 + if start_n < seqlen_k: + # This section does dk and dv + dk = tl.zeros([BLOCK_N1, HEAD_DIM_QK], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM_V], dtype=tl.float32) + + # q > k: diretcly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + # q < k: some blocks will have no Masked block, other needs to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N1 + delta_aligned = (num_blocks_skip + 1) * BLOCK_N1 + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M1 * BLOCK_M1 + if delta_qk >= 0: + start_delta = delta_qk + if DEBUG_TRITON: + print( + f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}" + ) # noqa: E701 + else: + start_delta = start_delta_q_lt_k + if DEBUG_TRITON: + print( + f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}" + ) # noqa: E701 + + offs_n = start_n + tl.arange(0, BLOCK_N1) + # Mask for loading K and V + mask_k = offs_n[:, None] < seqlen_k + mask_v = offs_n[:, None] < seqlen_k + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_k &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_v &= mask_d_v[None, :] + + # K/V tensors not changed for the group + adj_k = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_d_qk[None, :] * stride_kd + ) + adj_v = ( + bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_d_v[None, :] * stride_vd + ) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_k, mask=mask_k, other=0.0) + v = tl.load(V + adj_v, mask=mask_v, other=0.0) + # If MQA / GQA, set the K and V head offsets appropriately. + # hqid = hkid + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N1 + else: + start_m = max(start_n + delta_qk, 0) + start_m = start_m // BLOCK_M1 * BLOCK_M1 + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N1 + residue_m + if DEBUG_TRITON: + print(f"residue_m = {residue_m}") # noqa: E701 + + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = ( + bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m + ) + Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M1) + # when q < k, we may skip the initial masked op + if pid < num_blocks_skip: + num_steps = 0 + + # if start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + if DEBUG_TRITON: + print( + f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}" + ) # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, + dv, # output tensors + Q_ptr, + k, + v, + DO_ptr, + M_ptr, + Delta_ptr, + sm_scale, # input tensors + stride_qm, + stride_qd, # strides for q + stride_dom, + stride_dod, # strides for o + stride_dropoutm, + stride_dropoutn, # strides for dropout + stride_lse_m, + stride_delta_m, + MASK_BLOCK_M1, + BLOCK_N1, # block dim + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, # head dim + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + descale_do, + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + start_m += num_steps * MASK_BLOCK_M1 + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M1) + end_m = start_m + num_steps * BLOCK_M1 + + if DEBUG_TRITON: + print( + f"start_m after Masked step: {start_m}; num_steps: {num_steps}" + ) # noqa: E701 + if DEBUG_TRITON: + print( + f"unMasked: start_n: {start_n}, start_m: {start_m}, end_m: {end_m}, num_steps: {num_steps}" + ) # noqa: E701 + if DEBUG_TRITON: + print("unMasked") # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, + dv, # output tensors + Q_ptr, + k, + v, + DO_ptr, + M_ptr, + Delta_ptr, + sm_scale, # input tensors + stride_qm, + stride_qd, # strides for q + stride_dom, + stride_dod, # strides for o + stride_dropoutm, + stride_dropoutn, # strides for dropout + stride_lse_m, + stride_delta_m, + BLOCK_M1, + BLOCK_N1, # block dim + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, # head dim + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + descale_do, + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # end of GQA/MQA of dkdv + # Write back dV + adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn + offs_dv = offs_n[:, None] * stride_dvn + offs_d_v[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_v) + # write back dk + adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn + offs_dk = offs_n[:, None] * stride_dkn + offs_d_qk[None, :] * stride_dkd + dk *= sm_scale + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_k) + + # This part does dq + start_m = pid * BLOCK_M2 + if start_m < seqlen_q: + # seqlen_q > seqlen_k, no need to process these tile for dq + if DEBUG_TRITON: + print( + f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2}" + ) # noqa: E701 + if start_m + BLOCK_M2 < delta_qk: + if DEBUG_TRITON: + print( + f"start_m + BLOCK_M2 = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2} < delta_qk of {delta_qk}" + ) # noqa: E701 + return + + offs_m = start_m + tl.arange(0, BLOCK_M2) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + mask_do = offs_m[:, None] < seqlen_q + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_q &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_do &= mask_d_v[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d_v[None, :] * stride_dod + # NOTE: don't assume that the strides for k and v are the same! + K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn + V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn + + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front + # for every M-tile + end_n = start_m + BLOCK_M2 - delta_qk + # clamp end_n at [0, seqlen_k] + end_n = max(min(end_n, seqlen_k), 0) + if DEBUG_TRITON: + print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = ( + bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m + ) + Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_do, other=0.0) + m = tl.load(M + adj_m + offs_m * stride_lse_m, mask=offs_m < seqlen_q) + m = m[:, None] + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + # start can only be 0 at minimum + start_n = max(end_n - BLOCK_M2, 0) + num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N2) + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, + q, + K, + V, + do, + m, + Delta_ptr, + sm_scale, + stride_qm, + stride_qd, + stride_kn, + stride_kd, + stride_vn, + stride_vd, + stride_dropoutm, + stride_dropoutn, + stride_lse_m, + stride_delta_m, + seqlen_q, + seqlen_k, + BLOCK_M2, + MASK_BLOCK_N2, + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + descale_do, + MASK=True, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + end_n -= num_steps * MASK_BLOCK_N2 + num_steps = tl.cdiv(end_n, BLOCK_N2) + start_n = max(end_n - num_steps * BLOCK_N2, 0) + if DEBUG_TRITON: + print( + f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}" + ) # noqa: E701 + dq = _bwd_dq_inner( + dq, + q, + K, + V, + do, + m, + Delta_ptr, + sm_scale, + stride_qm, + stride_qd, + stride_kn, + stride_kd, + stride_vn, + stride_vd, + stride_dropoutm, + stride_dropoutn, + stride_lse_m, + stride_delta_m, + seqlen_q, + seqlen_k, + BLOCK_M2, + BLOCK_N2, + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + descale_do, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_d_qk[None, :] * stride_dqd + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + # end of GQA/MQA of dq + + +@triton.autotune( + configs=noncausal_autotune_configs, + key=noncausal_autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def bwd_kernel_noncausal( + Q, + K, + V, + sm_scale, + DO, + DQ, + DK, + DV, + M, + Delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_descale_do_z, + stride_az, + stride_ah, + HQ, + HK, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Add seqused parameters + max_seqlen_q, + max_seqlen_k, + Dropout_mask, + dropout_p, + philox_seed, + philox_offset_base, + Alibi_slopes, + Descale_q, + Descale_k, + Descale_v, + Descale_do, + BLOCK_M1: tl.constexpr, # 32 + BLOCK_N1: tl.constexpr, # 128 + BLOCK_M2: tl.constexpr, # 128 + BLOCK_N2: tl.constexpr, # 32 + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, + USE_SEQUSED: tl.constexpr, # Add flag for seqused + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + hkid = tl.program_id(0) + pid = tl.program_id(1) + bid = tl.program_id(2) + if DEBUG_TRITON: + print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = ( + tl.load(seqused_q + bid) if seqused_q is not None else q_end - q_start + ) + seqlen_q = tl.minimum(actual_seqlen_q, q_end - q_start) + actual_seqlen_k = ( + tl.load(seqused_k + bid) if seqused_k is not None else k_end - k_start + ) + seqlen_k = tl.minimum(actual_seqlen_k, k_end - k_start) + else: + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + PADDED_HEAD_QK: tl.constexpr = ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK + PADDED_HEAD_V: tl.constexpr = ACTUAL_HEAD_DIM_V != HEAD_DIM_V + offs_d_qk = tl.arange(0, HEAD_DIM_QK) + offs_d_v = tl.arange(0, HEAD_DIM_V) + GROUP_SIZE: tl.constexpr = HQ // HK + + start_n = pid * BLOCK_N1 + if start_n < seqlen_k: + dk = tl.zeros([BLOCK_N1, HEAD_DIM_QK], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM_V], dtype=tl.float32) + + offs_n = start_n + tl.arange(0, BLOCK_N1) + # Mask for loading K and V + mask_k = offs_n[:, None] < seqlen_k + mask_v = offs_n[:, None] < seqlen_k + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_k &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_v &= mask_d_v[None, :] + # NOTE: don't assume that the strides for k and v are the same! + # K/V tensors not changed for the group + adj_k = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_d_qk[None, :] * stride_kd + ) + adj_v = ( + bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_d_v[None, :] * stride_vd + ) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_k, mask=mask_k, other=0.0) + v = tl.load(V + adj_v, mask=mask_v, other=0.0) + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = ( + bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m + ) + Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + # because there is no causal, we always start from the beginning + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M1) + dk, dv = _bwd_dkdv_inner( + dk, + dv, # output tensors + Q_ptr, + k, + v, + DO_ptr, + M_ptr, + Delta_ptr, + sm_scale, # input tensors + stride_qm, + stride_qd, # strides for q + stride_dom, + stride_dod, # strides for o + stride_dropoutm, + stride_dropoutn, # strides for dropout + stride_lse_m, + stride_delta_m, + BLOCK_M1, + BLOCK_N1, # block dim + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, # head dim + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, # + alibi_slope, + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + descale_do, # fp8 descale factors from user + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # Write back dV + adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn + offs_dv = offs_n[:, None] * stride_dvn + offs_d_v[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_v) + # write back dk + adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn + offs_dk = offs_n[:, None] * stride_dkn + offs_d_qk[None, :] * stride_dkd + dk *= sm_scale + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_k) + + # THIS PART DOES DQ + start_m = pid * BLOCK_M2 + if start_m < seqlen_q: + offs_m = start_m + tl.arange(0, BLOCK_M2) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + mask_do = offs_m[:, None] < seqlen_q + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_q &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_do &= mask_d_v[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d_v[None, :] * stride_dod + K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn + V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = ( + bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m + ) + Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_do, other=0.0) + m = tl.load(M + adj_m + offs_m * stride_lse_m, mask=offs_m < seqlen_q) + m = m[:, None] + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + # start can only be 0 at minimum + start_n = 0 + end_n = seqlen_k + num_steps = tl.cdiv(seqlen_k, BLOCK_N2) + + dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, + q, + K, + V, + do, + m, + Delta_ptr, + sm_scale, + stride_qm, + stride_qd, + stride_kn, + stride_kd, + stride_vn, + stride_vd, + stride_dropoutm, + stride_dropoutn, + stride_lse_m, + stride_delta_m, + seqlen_q, + seqlen_k, + BLOCK_M2, + BLOCK_N2, + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + descale_do, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_d_qk[None, :] * stride_dqd + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + + +def is_contiguous(x, name): + if x.is_contiguous(): + return x + else: + print(f"{name} is not contiguous") + return x.contiguous() + + +DEBUG_TRITON: bool = False +DEBUG_TRITON_DETAIL: bool = False + + +def attention_backward_triton_split_fused_no_atomics_impl( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + dropout_p: float, + philox_seed: Optional[int], + philox_offset: Optional[int], + use_exp2: bool, + # fp8 + descale_q: Optional[torch.Tensor], + descale_k: Optional[torch.Tensor], + descale_v: Optional[torch.Tensor], + descale_o: Optional[torch.Tensor], + descale_do: Optional[torch.Tensor], + descale_dq: Optional[torch.Tensor], + descale_dk: Optional[torch.Tensor], + descale_dv: Optional[torch.Tensor], + # seqused for FA v3 + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, +): + # get params, strides and shape + IS_VARLEN = layout == "thd" + use_dropout = dropout_p > 0.0 + + # common assertions + assert ( + 0.0 <= dropout_p <= 1.0 + ), f"dropout_p must be between 0 and 1, got {dropout_p}" + assert ( + q.device == k.device == v.device == o.device == do.device == softmax_lse.device + ), f"All tensors must be on the same device. Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}, do={do.device}, softmax_lse={softmax_lse.device}" + assert ( + q.dtype == k.dtype == v.dtype == do.dtype + ), "q, k, v, do must have the same dtype" + current_device = torch.cuda.current_device() + assert ( + q.is_cuda and q.device.index == current_device + ), f"Device mismatch: Kernel will launch on cuda:{current_device}, but tensors are on {q.device}" + + # get shapes and strides + if IS_VARLEN: + # shape + total_seqlen_q, nheads_q, head_size_q = q.shape + total_seqlen_k, nheads_k, head_size_k = k.shape + total_seqlen_v, nheads_v, head_size_v = v.shape + nheads_lse, total_seqlen_lse = softmax_lse.shape + + # assert shapes + assert ( + total_seqlen_lse == total_seqlen_q + ), f"softmax_lse seqlen {total_seqlen_lse} != q seqlen {total_seqlen_q}" + assert ( + cu_seqlens_q is not None + ), "cu_seqlens_q must be provided for varlen layout" + assert ( + cu_seqlens_k is not None + ), "cu_seqlens_k must be provided for varlen layout" + assert ( + max_seqlen_q is not None + ), "max_seqlen_q must be provided for varlen layout" + assert ( + max_seqlen_k is not None + ), "max_seqlen_k must be provided for varlen layout" + + # assert head dimensions + assert ( + head_size_q == head_size_k + ), f"head sizes must match: q={head_size_q}, k={head_size_k}" + assert ( + nheads_k == nheads_v + ), f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert ( + nheads_q % nheads_k == 0 + ), f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + assert ( + nheads_lse == nheads_q + ), f"softmax_lse heads {nheads_lse} != q heads {nheads_q}" + + # assert output shapes + assert o.shape == ( + total_seqlen_q, + nheads_q, + head_size_v, + ), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_v)}" + assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" + assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}" + assert dv.shape == v.shape, f"dv shape {dv.shape} != v shape {v.shape}" + + # assert cu_seqlens + assert ( + cu_seqlens_q.dtype == torch.int32 + ), f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}" + assert ( + cu_seqlens_k.dtype == torch.int32 + ), f"cu_seqlens_k must be int32, got {cu_seqlens_k.dtype}" + assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0" + assert cu_seqlens_k[0] == 0, "cu_seqlens_k must start with 0" + assert ( + cu_seqlens_q[-1] == total_seqlen_q + ), f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} != total_seqlen_q {total_seqlen_q}" + assert ( + cu_seqlens_k[-1] == total_seqlen_k + ), f"cu_seqlens_k[-1] {cu_seqlens_k[-1]} != total_seqlen_k {total_seqlen_k}" + + # set vars + batch = len(cu_seqlens_q) - 1 + head_size_qk = head_size_q + + # strides + stride_qb, stride_qm, stride_qh, stride_qd = ( + 0, + q.stride(0), + q.stride(1), + q.stride(2), + ) + stride_kb, stride_kn, stride_kh, stride_kd = ( + 0, + k.stride(0), + k.stride(1), + k.stride(2), + ) + stride_vb, stride_vn, stride_vh, stride_vd = ( + 0, + v.stride(0), + v.stride(1), + v.stride(2), + ) + stride_ob, stride_om, stride_oh, stride_od = ( + 0, + o.stride(0), + o.stride(1), + o.stride(2), + ) + stride_dqb, stride_dqm, stride_dqh, stride_dqd = ( + 0, + dq.stride(0), + dq.stride(1), + dq.stride(2), + ) + stride_dkb, stride_dkn, stride_dkh, stride_dkd = ( + 0, + dk.stride(0), + dk.stride(1), + dk.stride(2), + ) + stride_dvb, stride_dvn, stride_dvh, stride_dvd = ( + 0, + dv.stride(0), + dv.stride(1), + dv.stride(2), + ) + stride_dob, stride_dom, stride_doh, stride_dod = ( + 0, + do.stride(0), + do.stride(1), + do.stride(2), + ) + stride_lse_b, stride_lse_h, stride_lse_m = ( + 0, + softmax_lse.stride(0), + softmax_lse.stride(1), + ) + else: + # shapes + batch_q, seqlen_q, nheads_q, head_size_q = q.shape + batch_k, seqlen_k, nheads_k, head_size_k = k.shape + batch_v, seqlen_v, nheads_v, head_size_v = v.shape + batch_lse, nheads_lse, seqlen_lse = softmax_lse.shape + + # assert batch dimensions + assert ( + batch_q == batch_k == batch_v + ), f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" + + # assert head dimensions + assert ( + head_size_q == head_size_k + ), f"head sizes must match: q={head_size_q}, k={head_size_k}" + assert ( + nheads_k == nheads_v + ), f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert ( + nheads_q % nheads_k == 0 + ), f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + + # assert sequence lengths + assert ( + seqlen_k == seqlen_v + ), f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" + + # assert output shapes + assert o.shape == ( + batch_q, + seqlen_q, + nheads_q, + head_size_v, + ), f"o shape {o.shape} != expected" + assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" + assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}" + assert dv.shape == v.shape, f"dv shape {dv.shape} != v shape {v.shape}" + + # assert softmax_lse shape + assert softmax_lse.shape == ( + batch_q, + nheads_q, + seqlen_q, + ), f"softmax_lse shape {softmax_lse.shape} != expected" + + # set vars + batch = batch_q + head_size_qk = head_size_q + max_seqlen_q = seqlen_q + max_seqlen_k = seqlen_k + + # strides + stride_qb, stride_qm, stride_qh, stride_qd = q.stride() + stride_kb, stride_kn, stride_kh, stride_kd = k.stride() + stride_vb, stride_vn, stride_vh, stride_vd = v.stride() + stride_ob, stride_om, stride_oh, stride_od = o.stride() + stride_dqb, stride_dqm, stride_dqh, stride_dqd = dq.stride() + stride_dkb, stride_dkn, stride_dkh, stride_dkd = dk.stride() + stride_dvb, stride_dvn, stride_dvh, stride_dvd = dv.stride() + stride_dob, stride_dom, stride_doh, stride_dod = do.stride() + stride_lse_b, stride_lse_h, stride_lse_m = softmax_lse.stride() + + # fp8 setup - moved after all assertions + IS_FP8 = is_fp8(q) + if IS_FP8: + FP8_MAX = torch.finfo(q.dtype).max + # we already asserted that do, q, k, v all have the same dtype, so no need to check each one + if is_fp8(o): + FP8_OUTPUT = True + assert ( + descale_o is not None + ), f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o." + assert ( + descale_dq is not None + ), f"descale_dq is None. In fp8, you need to pass a tensor for descale_dq along with a tensor dq." + assert ( + descale_dk is not None + ), f"descale_dk is None. In fp8, you need to pass a tensor for descale_dk along with a tensor dk." + assert ( + descale_dv is not None + ), f"descale_dv is None. In fp8, you need to pass a tensor for descale_dv along with a tensor dv." + else: + FP8_OUTPUT = False + + stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None + stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None + stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None + stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None + stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None + + if DEBUG: + print(f"FP8 path triggered (FP8_OUTPUT={FP8_OUTPUT})") + else: + FP8_MAX = None + FP8_OUTPUT = False + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = ( + stride_descale_o_z + ) = stride_descale_do_z = None + + # alibi setup + use_alibi, (stride_az, stride_ah) = ( + (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) + ) + + # get closest power of 2 over or equal to 32. + padded_d_model_qk = 1 << (head_size_qk - 1).bit_length() + padded_d_model_qk = max(padded_d_model_qk, 32) + padded_d_model_v = 1 << (head_size_v - 1).bit_length() + padded_d_model_v = max(padded_d_model_v, 32) + HEAD_DIM_QK = padded_d_model_qk + HEAD_DIM_V = padded_d_model_v + ACTUAL_HEAD_DIM_QK = head_size_qk + ACTUAL_HEAD_DIM_V = head_size_v + + # init delta + if IS_VARLEN: + # Shape expected by interface varlen backward: (Hq, Total_Q) + total_q, _, _ = q.shape + delta = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) + stride_delta_b, stride_delta_h, stride_delta_m = ( + 0, + delta.stride(0), + delta.stride(1), + ) + else: + # Shape expected by dense backward: (B, Hq, Sq) + seqlen_q = q.shape[1] + delta = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) + stride_delta_b, stride_delta_h, stride_delta_m = delta.stride() + + pre_grid = lambda META: ( + triton.cdiv(max_seqlen_q, META["PRE_BLOCK"]), + batch, + nheads_q, + ) + _bwd_preprocess[pre_grid]( + o, + do, + delta, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_descale_do_z, + cu_seqlens_q, + max_seqlen_q, + descale_do, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + ) + + if False: + print("delta:", delta, delta.shape) + + # dropout mask tensor for debugging. We dump the dropout mask created in + # the kernel for testing + dropout_mask = None + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = (0, 0, 0, 0) + if use_dropout: + dropout_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=torch.float32, + ) + + if DROPOUT_USE_PYTORCH: + if not IS_VARLEN: + dropout_mask = create_dropout_mask( + dropout_p, + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + seed=philox_seed, + ) + else: + dropout_mask = create_dropout_mask_varlen( + dropout_p, batch, nheads_q, cu_seqlens_q, cu_seqlens_k, philox_seed + ) + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = ( + dropout_mask.stride() + ) + + seqlen = max(max_seqlen_q, max_seqlen_k) + grid = lambda META: ( + nheads_k, + (seqlen + META["BLOCK_N1"] - 1) // META["BLOCK_N1"], + batch, + ) + if causal: + if DEBUG_TRITON: + print(f"bwd_kernel: grid = {grid}") # noqa: E701 + bwd_kernel_causal[grid]( + q, + k, + v, + sm_scale, + do, + dq, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_descale_do_z, + stride_az, + stride_ah, + nheads_q, + nheads_k, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Pass seqused tensors + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, + descale_q, + descale_k, + descale_v, + descale_do, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=FP8_OUTPUT, + USE_SEQUSED=( + seqused_q is not None or seqused_k is not None + ), # Add flag for seqused + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + else: + bwd_kernel_noncausal[grid]( + q, + k, + v, + sm_scale, + do, + dq, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_descale_do_z, + stride_az, + stride_ah, + nheads_q, + nheads_k, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Pass seqused tensors + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, + descale_q, + descale_k, + descale_v, + descale_do, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=FP8_OUTPUT, + USE_SEQUSED=( + seqused_q is not None or seqused_k is not None + ), # Add flag for seqused + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + return delta + + +def attention_backward_triton_fused_atomics_impl( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + philox_seed: Optional[int] = 0, + philox_offset: Optional[int] = 0, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None, + fused: bool = False, + # seqused for FA v3 (currently ignored in this implementation) + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, +): + IS_FP8 = is_fp8(q) + if IS_FP8: + FP8_MAX = torch.finfo(q.dtype).max + descale_strides = ( + descale_q.stride(0), + descale_k.stride(0), + descale_v.stride(0), + descale_do.stride(0), + ) + + if DEBUG: + print(f"FP8 path triggered") + else: + FP8_MAX = None + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = ( + stride_descale_do_z + ) = None + descale_strides = ( + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_descale_do_z, + ) + + IS_VARLEN = True if cu_seqlens_q is not None else False + + # get strides and shape + if IS_VARLEN: + # Layout for q,k,v is thd ie [total tokens, num_head, head_dim] + batch, seqlen_q, num_q_heads, head_sz = ( + len(cu_seqlens_q) - 1, + max_seqlen_q, + q.shape[1], + q.shape[2], + ) + seqlen_k, num_k_heads = max_seqlen_k, k.shape[1] + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + dq_strides = (0, dq.stride(1), dq.stride(0), dq.stride(2)) + dk_strides = (0, dk.stride(1), dk.stride(0), dk.stride(2)) + dv_strides = (0, dv.stride(1), dv.stride(0), dv.stride(2)) + do_strides = (0, do.stride(1), do.stride(0), do.stride(2)) + else: + # Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] + batch, seqlen_q, num_q_heads, head_sz = q.shape + seqlen_k, num_k_heads = k.shape[1], k.shape[2] + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + dq_strides = (dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3)) + dk_strides = (dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3)) + dv_strides = (dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3)) + do_strides = (do.stride(0), do.stride(2), do.stride(1), do.stride(3)) + + # BLOCK_D_MODEL, BLOCK_D_MODEL_POW2 + # padding for head_dim. Power of 2 or 16 + BLOCK_D_MODEL_POW2 = triton.next_power_of_2(head_sz) + BLOCK_D_MODEL_POW2 = max(BLOCK_D_MODEL_POW2, 16) + + # Configs + # PRE_BLOCK, BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 + # BLK_SLICE_FACTOR + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + PRE_BLOCK = 128 + # BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 64, 64, 16 + BLK_SLICE_FACTOR = 2 + + # init delta + delta = torch.zeros_like(softmax_lse) + if IS_VARLEN: + # [total_tokens, num_q_heads, seqlen_q] + delta_strides = (0, delta.stride(1), delta.stride(0)) + else: + # [batch, num_q_heads, seqlen_q] + delta_strides = delta.stride() + + # preprocess + # compute D(delta) = rowsum(dO*O). Note, multiplication is element-wise. + pre_grid = (triton.cdiv(max_seqlen_q, PRE_BLOCK), batch, num_q_heads) + _bwd_fused_atomics_preprocess[pre_grid]( + o, + do, + delta, + *o_strides, + *delta_strides, + descale_strides[3], + cu_seqlens_q, + max_seqlen_q, + descale_do, + BLOCK_M=PRE_BLOCK, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + ) + + # dropout_mask + use_dropout = dropout_p > 0.0 + if use_dropout: + dropout_mask = torch.zeros( + (batch, num_q_heads, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=torch.float32, + ) + dropout_strides = dropout_mask.stride() + else: + dropout_mask = None + dropout_strides = (0, 0, 0, 0) + + grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, num_k_heads) + grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, num_k_heads) + + if ( + fused + ): # fuses dk, dv, dq computations into one kernel by computing the dq using atomic adds between workgroups + + BLOCK_N = ( + 128 if BLOCK_D_MODEL_POW2 < 160 else 64 + ) # larger head sizes lead to oom + config = { + "BLOCK_M": 32, + "BLOCK_N": BLOCK_N, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 1, + "BLK_SLICE_FACTOR": 2, + } + + num_k_pids = (max_seqlen_k + BLOCK_N - 1) // BLOCK_N + grid_dkdvdq = (batch * num_k_heads * num_k_pids,) + + if causal: + _bwd_kernel_fused_atomics_dkdvdq_causal[grid_dkdvdq]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + dq, + softmax_lse, + delta, + *q_strides, + *k_strides, + *v_strides, + *dk_strides, + *dq_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BATCH=batch, + NUM_K_PIDS=num_k_pids, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + **config, + ) + else: + _bwd_kernel_fused_atomics_dkdvdq_noncausal[grid_dkdvdq]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + dq, + softmax_lse, + delta, + *q_strides, + *k_strides, + *v_strides, + *dk_strides, + *dq_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BATCH=batch, + NUM_K_PIDS=num_k_pids, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + **config, + ) + + return delta + + # split kernels solution: one kernel computes dk, dv and the other computes dq + + if causal: + _bwd_kernel_fused_atomics_dkdv_causal[grid_dkdv]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + softmax_lse, + delta, + *q_strides, + *k_strides, + *v_strides, + *dk_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BLOCK_M=BLOCK_M1, + BLOCK_N=BLOCK_N1, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + _bwd_kernel_fused_atomics_dq_causal[grid_dq]( + q, + k, + v, + sm_scale, + do, + dq, + softmax_lse, + delta, + *q_strides, + *k_strides, + *v_strides, + *dq_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BLOCK_M=BLOCK_M2, + BLOCK_N=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + else: + _bwd_kernel_fused_atomics_dkdv_noncausal[grid_dkdv]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + softmax_lse, + delta, + *q_strides, + *k_strides, + *v_strides, + *dk_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BLOCK_M=BLOCK_M1, + BLOCK_N=BLOCK_N1, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + + _bwd_kernel_fused_atomics_dq_noncausal[grid_dq]( + q, + k, + v, + sm_scale, + do, + dq, + softmax_lse, + delta, + *q_strides, + *k_strides, + *v_strides, + *dq_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BLOCK_M=BLOCK_M2, + BLOCK_N=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + + return delta + + +def attention_backward_triton_impl( + *, + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + layout: str, + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + philox_seed: Optional[int] = None, + philox_offset: Optional[int] = None, + use_exp2: bool = True, + mode: str = "fused_no_atomics", +) -> torch.Tensor: + """Unified backward interface dispatching to atomics or no-atomics implementation. + + Parameters mirror the superset of the two legacy interfaces. The public API should + call ONLY this function going forward. + mode: 'fused_atomics' or 'fused_no_atomics'; layout: 'bshd' or 'thd'; use_exp2 retained for parity. + """ + # Enforce supported dtypes (mirror Hopper behavior: FP8 forward-only) + supported_dtypes = {torch.float16, torch.bfloat16, torch.float32} + for name, t in {"q": q, "k": k, "v": v, "o": o, "do": do}.items(): + if t.dtype not in supported_dtypes: + raise TypeError( + f"Backward only supports fp16/bf16/fp32; tensor '{name}' has dtype {t.dtype}" + ) + + if mode == "fused_atomics": + # Atomics path ignores layout & use_exp2; pass varlen metadata directly. + return attention_backward_triton_fused_atomics_impl( + do, + q, + k, + v, + o, + softmax_lse, + dq, + dk, + dv, + sm_scale, + alibi_slopes, + causal, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q if max_seqlen_q is not None else q.shape[1], + max_seqlen_k if max_seqlen_k is not None else k.shape[1], + dropout_p, + philox_seed or 0, + philox_offset or 0, + None, + None, + None, + None, + True, # fused flag + None, + None, + ) + elif mode == "fused_no_atomics": + return attention_backward_triton_split_fused_no_atomics_impl( + do, + q, + k, + v, + o, + softmax_lse, + dq, + dk, + dv, + sm_scale, + alibi_slopes, + causal, + layout, # layout required here + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + use_exp2, + None, + None, + None, + None, + None, + None, + None, + None, + seqused_q, + seqused_k, + ) + else: + raise ValueError( + f"Unknown backward mode '{mode}'. Expected 'fused_atomics' or 'fused_no_atomics'." + ) diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py deleted file mode 100755 index 51e53daedc2..00000000000 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py +++ /dev/null @@ -1,1815 +0,0 @@ -import torch -import triton -import triton.language as tl -from flash_attn.flash_attn_triton_amd.utils import compute_fp8_scaling_factors, DEBUG, is_fp8 - -from typing import Optional, Tuple - -# This function computes delta given output Out and gradient DO -# Here is the I/O shape: -# Out: (batch, nhead_q, max_seqlens_q, headDim) -# DO: (batch, nhead_q, max_seqlens_q, headDim) -# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at -@triton.jit -def _bwd_preprocess( - o_ptr, do_ptr, # noqa: E741 - delta_ptr, - stride_o_b, stride_o_h, stride_o_m, stride_o_k, - stride_delta_b, stride_delta_h, stride_delta_m, - stride_descale_do_z, - cu_seqlens_q, max_seqlen_q, - descale_do_ptr, - BLOCK_M: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr -): - pid_m = tl.program_id(0) #seqlen - bid = tl.program_id(1) #batch - hid = tl.program_id(2) #head - - # Handle varlen - q_start = 0 - seqlen_q = max_seqlen_q - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - seqlen_q = q_end - q_start - else: - q_start = 0 - seqlen_q = max_seqlen_q - - # Compute offsets - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - - # Offset O/DO by batch, head and q_start - offs = (bid * stride_o_b + - hid * stride_o_h + - q_start * stride_o_m + offs_m[:, None] * stride_o_m + - offs_k[None, :] * stride_o_k) - - # create masks - mask_m = offs_m < seqlen_q - mask = mask_m[:, None] - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - if PADDED_HEAD: - mask &= offs_k[None, :] < BLOCK_D_MODEL - - # load [BLOCK_M, BLOCK_D_MODEL_POW2] - o = tl.load(o_ptr + offs, mask=mask, other=0.0) - do = tl.load(do_ptr + offs, mask=mask, other=0.0) - - # compute and write-back to delta - if IS_FP8: - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hid) - - # NOTE: do is in the fp8 range and o is not in fp8 - delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) - else: - delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) - - offs_delta = (bid * stride_delta_b + - hid * stride_delta_h + - q_start * stride_delta_m + offs_m * stride_delta_m) - tl.store(delta_ptr + offs_delta, delta, mask=mask_m) - -@triton.jit -def _bwd_dq_inner( - dq, - q, K, V, do, m, Delta, sm_scale, - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropout_m, stride_dropout_n, - stride_deltam, - seqlen_q, seqlen_k, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - MASK: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - RCP_LN2: tl.constexpr = 1.4426950408889634 - - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M) - offs_n = start_n + tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - - # mask to make sure not OOB of seqlen_q - mask_m = offs_m < seqlen_q - - kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk - vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk - - # D (= delta) is pre-divided by ds_scale. - Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) - - curr_n = start_n - step_n = BLOCK_N - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - for blk_idx in range(num_steps): - offs_n = curr_n + tl.arange(0, BLOCK_N) - # end_n is needed because the end of causal True might not be perfectly - # aligned with the end of the block - mask_n = offs_n < end_n - mask_kT = mask_n[None, :] - mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) - if PADDED_HEAD: - mask_kT &= offs_k[:, None] < BLOCK_D_MODEL - - kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) - vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) - - #dropout - if ENABLE_DROPOUT: - philox_offs = (curr_philox_offset + - offs_m[:, None] * stride_dropout_m + - offs_n[None, :] * stride_dropout_n) - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1 / (1 - dropout_p) - - #qk - if IS_FP8: - qk = tl.dot(q, kT) * descale_q * descale_k - else: - qk = tl.dot(q, kT) - p = tl.math.exp2(qk * sm_scale * RCP_LN2 - m * RCP_LN2) - - if MASK: - causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] - mask = causal_mask * mask_mn - p = tl.where(mask, p, 0.0) - - #dp - if IS_FP8: - dp = (tl.dot(do, vT) * descale_do * descale_v) - else: - dp = tl.dot(do, vT) - - if ENABLE_DROPOUT: - dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale - - #ds - delta_i = Di[:, None] - ds = p * (dp - delta_i) - - #dq - # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. - if IS_FP8: - scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) - dq += (tl.dot((ds*scale_ds).to(kT.type.element_ty), tl.trans(kT)) * descale_ds * descale_k) - else: - dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) - - curr_n += step_n - kT_ptrs += step_n * stride_kn - vT_ptrs += step_n * stride_vn - return dq - - -@triton.jit -def _bwd_dkdv_inner( - dk, dv, - Q, k, v, DO, M, D, sm_scale, - stride_q_m, stride_q_k, - stride_do_m, stride_do_k, - stride_dropout_m, stride_dropout_n, - stride_deltam, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - seqlen_q, seqlen_k, - start_n, start_m, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - MASK: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M) - offs_n = start_n + tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - - # mask to make sure not OOB of seqlen_q - mask_n = offs_n < seqlen_k - qT_ptrs = Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k #[BLOCK_D_MODEL_POW2, BLOCK_M] - do_ptrs = DO + offs_m[:, None] * stride_do_m + offs_k[None,: ] * stride_do_k - curr_m = start_m - step_m = BLOCK_M - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - RCP_LN2: tl.constexpr = 1.4426950408889634 - - #Iterate over blocks(BLOCK_M size) of Q while calculating - #a fixed block(BLOCK_N) of dk and dv. Note, during backward - #pass P has to be recomputed. However, this kernel computes - #dV and dK, so we compute we need P^T and S^T. See backward pass - #equations - # - #From Flash Attention Paper: - #ForwardPass: S = QkT, P=softmax(S), O=PV - # - #BackwardPass equations - #dV = P^TdO - #dP = dOV^T - #dS = dsoftmax(dP) - #dQ = dSK - #dK = QdS^T - for blk_idx in range(num_steps): - offs_m = curr_m + tl.arange(0, BLOCK_M) - mask_m = offs_m < seqlen_q - mask_qT = mask_m[None, :] - mask_do = mask_m[:, None] - mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) - if PADDED_HEAD: - mask_qT &= offs_k[:, None] < BLOCK_D_MODEL - mask_do &= offs_k[None, :] < BLOCK_D_MODEL - - #load qT - qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) - - #dropout - if ENABLE_DROPOUT: - # NOTE: dropout is transposed because it is used to mask pT - philox_offs = (curr_philox_offset + - offs_m[None, :] * stride_dropout_m + - offs_n[:, None] * stride_dropout_n) - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1.0 / (1 - dropout_p) - - #Load M - m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) - - #Compute qkT - if IS_FP8: - qkT = (tl.dot(k, qT) * descale_q * descale_k) - else: - qkT = tl.dot(k, qT) - - #Compute pT(use m and also apply sm_scale) - pT = tl.math.exp(qkT * sm_scale - m[None, :]) - - if MASK: - causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] - mask = causal_mask & mask_nm - pT = tl.where(mask, pT, 0.0) - - #load DO - do = tl.load(do_ptrs, mask=mask_do, other=0.0) - - #dV - if ENABLE_DROPOUT: - pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) - dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) * descale_p_dropout * descale_do) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) - else: - if IS_FP8: - scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) - dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) - else: - dv += tl.dot(pT.to(do.type.element_ty), do) - - #Load delta - Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) - - #Compute dP and dS - if IS_FP8: - dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do - else: - dpT = tl.dot(v, tl.trans(do)) - - if ENABLE_DROPOUT: - dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale - - delta_i = Di[None, :] - dsT = pT * (dpT - delta_i) - - #compute dk - if IS_FP8: - scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) - else: - dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) - - #increment pointers - curr_m += step_m - qT_ptrs += step_m * stride_q_m - do_ptrs += step_m * stride_do_m - - return dk, dv - - -@triton.jit -def _bwd_dkdvdq_inner( - dk, dv, - Q, k, v, DO, DQ, M, D, sm_scale, - stride_q_m, stride_q_k, - stride_dq_m, stride_dq_k, - stride_do_m, stride_do_k, - stride_dropout_m, stride_dropout_n, - stride_deltam, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - seqlen_q, seqlen_k, - start_n, start_m, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - MASK: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - workgroup_id: tl.int32, -): - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M) - offs_n = start_n + tl.arange(0, BLOCK_N) - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - - # mask to make sure not OOB of seqlen_q - mask_n = offs_n < seqlen_k - - qT_ptrs_start = Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k #[BLOCK_D_MODEL_POW2, BLOCK_M] - dq_ptrs_start = DQ + offs_m[:, None] * stride_dq_m + offs_k[None,:] * stride_dq_k #[BLOCK_M, BLOCK_D_MODEL_POW2] - - do_ptrs_start = DO + offs_m[:, None] * stride_do_m + offs_k[None,: ] * stride_do_k - curr_m = start_m - step_m = BLOCK_M - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - RCP_LN2: tl.constexpr = 1.4426950408889634 - - #Iterate over blocks(BLOCK_M size) of Q while calculating - #a fixed block(BLOCK_N) of dk and dv. Note, during backward - #pass P has to be recomputed. However, this kernel computes - #dV and dK, so we compute we need P^T and S^T. See backward pass - #equations - # - #From Flash Attention Paper: - #ForwardPass: S = QkT, P=softmax(S), O=PV - # - #BackwardPass equations - #dV = P^TdO - #dP = dOV^T - #dS = dsoftmax(dP) - #dQ = dSK - #dK = QdS^T - - # Compute a starting index and step based on workgroup_id - # Use a simple hash-like function to spread out the starting points - start_idx = (workgroup_id * 17) % num_steps # 17 is an arbitrary prime to spread indices - # Ensure step is coprime with num_steps to visit all indices exactly once - step = 1 # 3 if num_steps > 1 or num_steps==3 else 1 # coprime with num_steps - - - for iter in range(num_steps): - # Compute the permuted block index - blk_idx = (start_idx + iter * step) % num_steps - - curr_m = start_m + blk_idx * step_m - qT_ptrs = qT_ptrs_start + blk_idx * step_m * stride_q_m - dq_ptrs = dq_ptrs_start + blk_idx * step_m * stride_dq_m - do_ptrs = do_ptrs_start + blk_idx * step_m * stride_do_m - - offs_m = curr_m + tl.arange(0, BLOCK_M) - mask_m = offs_m < seqlen_q - mask_qT = mask_m[None, :] - mask_do = mask_m[:, None] - mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) - - if PADDED_HEAD: - mask_qT &= offs_k[:, None] < BLOCK_D_MODEL - mask_do &= offs_k[None, :] < BLOCK_D_MODEL - - #load qT - qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) - - #dropout - if ENABLE_DROPOUT: - # NOTE: dropout is transposed because it is used to mask pT - philox_offs = (curr_philox_offset + - offs_m[None, :] * stride_dropout_m + - offs_n[:, None] * stride_dropout_n) - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1.0 / (1 - dropout_p) - - #Load M - m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) - - #Compute qkT - if IS_FP8: - qkT = (tl.dot(k, qT) * descale_q * descale_k) - else: - qkT = tl.dot(k, qT) - - #Compute pT(use m and also apply sm_scale) - pT = tl.math.exp(qkT * sm_scale - m[None, :]) - - if MASK: - causal_mask = (offs_m[None, :] - delta_qk) >= (offs_n[:, None]) - mask = causal_mask & mask_nm - pT = tl.where(mask, pT, 0.0) - - #load DO - do = tl.load(do_ptrs, mask=mask_do, other=0.0) - - #dV - if ENABLE_DROPOUT: - pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) - dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) * descale_p_dropout * descale_do) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) - else: - if IS_FP8: - scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) - dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) - else: - dv += tl.dot(pT.to(do.type.element_ty), do) - - #Load delta - Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) - - #Compute dP and dS - if IS_FP8: - dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do - else: - dpT = tl.dot(v, tl.trans(do)) - - if ENABLE_DROPOUT: - dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale - - delta_i = Di[None, :] - dsT = pT * (dpT - delta_i) - - #compute dk - if IS_FP8: - scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) - else: - dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) - - - # We can compute the dq_partial here and do a atomic add to the correct memory location - # NOTE: Possible problems with the atomic add: contention, is inside a loop which has achieved bad perf before - # (BLOCK_M, BLOCK_N) x (BLOCK_N, D) - if IS_FP8: - dq_partial = tl.dot((dsT * scale_dsT).to(k.dtype).T, k) * descale_dsT * descale_k - else: - dq_partial = tl.dot(dsT.to(k.dtype).T, k) - tl.atomic_add( - dq_ptrs, - dq_partial * sm_scale, - mask=mask_m[:, None], - sem="relaxed", - ) - - return dk, dv - - -@triton.jit -def _bwd_kernel_dkdvdq_causal( - q_ptr, k_ptr, v_ptr, sm_scale, do_ptr, dk_ptr, dv_ptr, dq_ptr, - m_ptr, delta_ptr, - stride_q_b, stride_q_h, stride_q_m, stride_q_k, - stride_k_b, stride_k_h, stride_k_n, stride_k_k, - stride_v_b, stride_v_h, stride_v_n, stride_v_k, - stride_dk_b, stride_dk_h, stride_dk_n, stride_dk_k, - stride_dq_b, stride_dq_h, stride_dq_m, stride_dq_k, - stride_delta_b, stride_delta_h, stride_delta_m, - stride_do_b, stride_do_h, stride_do_m, stride_do_k, - stride_dropout_b, stride_dropout_h, stride_dropout_m, stride_dropout_n, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset_base, - descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BATCH, - NUM_K_PIDS, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - wid = tl.program_id(0) # workgoup id: 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1 - - # workgroups get launched first along batch dim, then in head_k dim, and then in seq k block dim - batch_idx = wid % BATCH - head_k_idx = wid // BATCH % NUM_K_HEADS - seq_k_blk_idx = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS - - #Determine q and k start along with seqlen_q and seqlen_k - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + batch_idx) - q_end = tl.load(cu_seqlens_q + batch_idx + 1) - k_start = tl.load(cu_seqlens_k + batch_idx) - k_end = tl.load(cu_seqlens_k + batch_idx + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - - # Figure out causal starting block since we have seqlen_q >=< seqlen_k. - # Unlike forward pass where we tile on M dim and iterate on N dim, so that - # we can skip some M blocks, in backward pass, we tile on the N dim for kv - # and iterate over the M. In this way, we cannot skip N blocks, but only to - # determine the starting M blocks to skip some initial blocks masked by - # causal. - delta_qk = seqlen_q - seqlen_k - - # q > k: diretcly skip all the way until the start of causal block - start_delta_q_gt_k = delta_qk - - # q < k: some blocks will have no Masked block, other needs to re-calc - # starting position - # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the - # masked op - num_blocks_skip = -delta_qk // BLOCK_N - delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk - start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M - if delta_qk >= 0: - start_delta = delta_qk - else: - start_delta = start_delta_q_lt_k - - start_n = seq_k_blk_idx * BLOCK_N - - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - offs_n = start_n + tl.arange(0, BLOCK_N) - # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - if PADDED_HEAD: - mask_k = offs_k < BLOCK_D_MODEL - mask_kv &= mask_k[None, :] - - GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS - adj_k = (batch_idx * stride_k_b + - head_k_idx * stride_k_h + - k_start * stride_k_n + offs_n[:, None] * stride_k_n + - offs_k[None, :] * stride_k_k) - adj_v = (batch_idx * stride_v_b + - head_k_idx * stride_v_h + - k_start * stride_v_n + offs_n[:, None] * stride_v_n + - offs_k[None, :] * stride_v_k) - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(k_ptr + adj_k , mask=mask_kv, other=0.0) - v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0) - - # If MQA / GQA, set the K and V head offsets appropriately. - for head_q_idx in range(head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE): - if delta_qk >= 0: - start_m = start_n + start_delta - len_m = BLOCK_N - else: - start_m = max(start_n + delta_qk, 0) - start_m = (start_m // BLOCK_M) * BLOCK_M - # because we might shift the masked blocks up, we are deeper into - # the masked out region, so we would potentially increase the total - # steps with masked operation to get out of it - residue_m = max(start_n + delta_qk - start_m, 0) - len_m = BLOCK_N + residue_m - - # offset input and output tensor by batch and Q/K heads - adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m - adj_dq = batch_idx * stride_dq_b + head_q_idx * stride_dq_h + q_start * stride_dq_m - - q_ptr_adj = q_ptr + adj_q - dq_ptr_adj = dq_ptr + adj_dq - - adj_do = batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m - do_ptr_adj = do_ptr + adj_do - adj_delta = batch_idx * stride_delta_b + head_q_idx * stride_delta_h + q_start * stride_delta_m - m_ptr_adj = m_ptr + adj_delta - delta_ptr_adj = delta_ptr + adj_delta - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = (philox_offset_base + batch_idx * stride_dropout_b + - head_q_idx * stride_dropout_h) - dropout_offset = (dropout_mask + batch_idx * stride_dropout_b + - head_q_idx * stride_dropout_h) - - MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR - # bound the masked operation to q len so it does not have to wast cycles - len_m = min(len_m, seqlen_q) - num_steps = tl.cdiv(len_m, MASK_BLOCK_M) - - - # when q < k, we may skip the initial masked op - # if seq_k_blk_idx < num_blocks_skip: - # num_steps = 0 - - if IS_FP8: - descale_q = tl.load(descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx) - descale_k = tl.load(descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx) - descale_v = tl.load(descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx) - descale_do = tl.load(descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - # if unaligned start_m is negative, the current N-tile has no block on the - # diagonal of causal mask, so everything have no causal mask - dk, dv = _bwd_dkdvdq_inner( - dk, dv, # output tensors - q_ptr_adj, k, v, do_ptr_adj, dq_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors - stride_q_m, stride_q_k, # strides for q - stride_dq_m, stride_dq_k, # strides for q - stride_do_m, stride_do_k, # strides for o - stride_dropout_m, stride_dropout_n, # strides for dropout - stride_delta_m, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK_BLOCK_M, BLOCK_N, # block dim - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim - MASK=True, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - workgroup_id=seq_k_blk_idx, - ) - - - start_m += num_steps * MASK_BLOCK_M - num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) - end_m = start_m + num_steps * BLOCK_M - - - - dk, dv = _bwd_dkdvdq_inner( - dk, dv, # output tensors - q_ptr_adj, k, v, do_ptr_adj, dq_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors - stride_q_m, stride_q_k, # strides for q - stride_dq_m, stride_dq_k, # strides for dq - stride_do_m, stride_do_k, # strides for o - stride_dropout_m, stride_dropout_n, # strides for dropout - stride_delta_m, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - BLOCK_M, BLOCK_N, # block dim - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim - MASK=False, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - workgroup_id=seq_k_blk_idx, - ) - - # Write back dV and dK. - offs_dkdv = (batch_idx * stride_dk_b + - head_k_idx * stride_dk_h + - k_start * stride_dk_n + offs_n[:, None] * stride_dk_n + - offs_k[None, :] * stride_dk_k) - tl.store(dv_ptr + offs_dkdv, dv, mask=mask_kv) - dk *= sm_scale - tl.store(dk_ptr + offs_dkdv, dk, mask=mask_kv) - - -@triton.jit -def _bwd_kernel_dkdv_causal( - q_ptr, k_ptr, v_ptr, sm_scale, do_ptr, dk_ptr, dv_ptr, - m_ptr, delta_ptr, - stride_q_b, stride_q_h, stride_q_m, stride_q_k, - stride_k_b, stride_k_h, stride_k_n, stride_k_k, - stride_v_b, stride_v_h, stride_v_n, stride_v_k, - stride_dk_b, stride_dk_h, stride_dk_n, stride_dk_k, - stride_delta_b, stride_delta_h, stride_delta_m, - stride_do_b, stride_do_h, stride_do_m, stride_do_k, - stride_dropout_b, stride_dropout_h, stride_dropout_m, stride_dropout_n, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset_base, - descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - #seq block, batch, head_k - seq_k_blk_idx = tl.program_id(0) - batch_idx = tl.program_id(1) - head_k_idx = tl.program_id(2) - - #Determine q and k start along with seqlen_q and seqlen_k - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + batch_idx) - q_end = tl.load(cu_seqlens_q + batch_idx + 1) - k_start = tl.load(cu_seqlens_k + batch_idx) - k_end = tl.load(cu_seqlens_k + batch_idx + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - - # Figure out causal starting block since we have seqlen_q >=< seqlen_k. - # Unlike forward pass where we tile on M dim and iterate on N dim, so that - # we can skip some M blocks, in backward pass, we tile on the N dim for kv - # and iterate over the M. In this way, we cannot skip N blocks, but only to - # determine the starting M blocks to skip some initial blocks masked by - # causal. - delta_qk = seqlen_q - seqlen_k - - # q > k: diretcly skip all the way until the start of causal block - start_delta_q_gt_k = delta_qk - - # q < k: some blocks will have no Masked block, other needs to re-calc - # starting position - # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the - # masked op - num_blocks_skip = -delta_qk // BLOCK_N - delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk - start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M - if delta_qk >= 0: - start_delta = delta_qk - else: - start_delta = start_delta_q_lt_k - - start_n = seq_k_blk_idx *BLOCK_N - - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - offs_n = start_n + tl.arange(0, BLOCK_N) - # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - if PADDED_HEAD: - mask_k = offs_k < BLOCK_D_MODEL - mask_kv &= mask_k[None, :] - - GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS - adj_k = (batch_idx * stride_k_b + - head_k_idx * stride_k_h + - k_start * stride_k_n + offs_n[:, None] * stride_k_n + - offs_k[None, :] * stride_k_k) - adj_v = (batch_idx * stride_v_b + - head_k_idx * stride_v_h + - k_start * stride_v_n + offs_n[:, None] * stride_v_n + - offs_k[None, :] * stride_v_k) - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(k_ptr + adj_k , mask=mask_kv, other=0.0) - v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0) - - # If MQA / GQA, set the K and V head offsets appropriately. - for head_q_idx in range(head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE): - if delta_qk >= 0: - start_m = start_n + start_delta - len_m = BLOCK_N - else: - start_m = max(start_n + delta_qk, 0) - start_m = start_m // BLOCK_M * BLOCK_M - # because we might shift the masked blocks up, we are deeper into - # the masked out region, so we would potentially increase the total - # steps with masked operation to get out of it - residue_m = max(start_n + delta_qk - start_m, 0) - len_m = BLOCK_N + residue_m - - # offset input and output tensor by batch and Q/K heads - adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m - q_ptr_adj = q_ptr + adj_q - adj_do = batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m - do_ptr_adj = do_ptr + adj_do - adj_delta = batch_idx * stride_delta_b + head_q_idx * stride_delta_h + q_start * stride_delta_m - m_ptr_adj = m_ptr + adj_delta - delta_ptr_adj = delta_ptr + adj_delta - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = (philox_offset_base + batch_idx * stride_dropout_b + - head_q_idx * stride_dropout_h) - dropout_offset = (dropout_mask + batch_idx * stride_dropout_b + - head_q_idx * stride_dropout_h) - - MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR - # bound the masked operation to q len so it does not have to wast cycles - len_m = min(len_m, seqlen_q) - num_steps = tl.cdiv(len_m, MASK_BLOCK_M) - # when q < k, we may skip the initial masked op - if seq_k_blk_idx < num_blocks_skip: - num_steps = 0 - - if IS_FP8: - descale_q = tl.load(descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx) - descale_k = tl.load(descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx) - descale_v = tl.load(descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx) - descale_do = tl.load(descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - # if start_m is negative, the current N-tile has no block on the - # diagonal of causal mask, so everything have no causal mask - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - q_ptr_adj, k, v, do_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors - stride_q_m, stride_q_k, # strides for q - stride_do_m, stride_do_k, # strides for o - stride_dropout_m, stride_dropout_n, # strides for dropout - stride_delta_m, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK_BLOCK_M, BLOCK_N, # block dim - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim - MASK=True, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - ) - start_m += num_steps * MASK_BLOCK_M - num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) - end_m = start_m + num_steps * BLOCK_M - - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - q_ptr_adj, k, v, do_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors - stride_q_m, stride_q_k, # strides for q - stride_do_m, stride_do_k, # strides for o - stride_dropout_m, stride_dropout_n, # strides for dropout - stride_delta_m, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - BLOCK_M, BLOCK_N, # block dim - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim - MASK=False, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - ) - - # Write back dV and dK. - offs_dkdv = (batch_idx * stride_dk_b + - head_k_idx * stride_dk_h + - k_start * stride_dk_n + offs_n[:, None] * stride_dk_n + - offs_k[None, :] * stride_dk_k) - tl.store(dv_ptr + offs_dkdv, dv, mask=mask_kv) - dk *= sm_scale - tl.store(dk_ptr + offs_dkdv, dk, mask=mask_kv) - -@triton.jit -def _bwd_kernel_dq_causal( - q_ptr, k_ptr, v_ptr, sm_scale, do_ptr, dq_ptr, - m_ptr, delta_ptr, - stride_q_b, stride_q_h, stride_q_m, stride_q_k, - stride_k_b, stride_k_h, stride_k_n, stride_k_k, - stride_v_b, stride_v_h, stride_v_n, stride_v_k, - stride_dq_b, stride_dq_h, stride_dq_m, stride_dq_k, - stride_delta_b, stride_delta_h, stride_delta_m, - stride_do_b, stride_do_h, stride_do_m, stride_do_k, - stride_dropout_b, stride_dropout_h, stride_dropout_m, stride_dropout_n, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset_base, - descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - seq_q_blk_idx = tl.program_id(0) - batch_idx = tl.program_id(1) - head_k_idx = tl.program_id(2) - - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + batch_idx) - q_end = tl.load(cu_seqlens_q + batch_idx + 1) - k_start = tl.load(cu_seqlens_k + batch_idx) - k_end = tl.load(cu_seqlens_k + batch_idx + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - # Figure out causal starting block since we have seqlen_q <=> seqlen_k. - # Unlike forward pass where we tile on M dim and iterate on N dim, so that - # we can skip some M blocks, in backward pass, we tile on the N dim for kv - # and iterate over the M. In this way, we cannot skip N blocks, but only to - # determine the starting M blocks to skip some initial blocks masked by - # causal. - # DQ tiles on M dim and iterate on N dim, so we there could be some tiles we - # can simply skip and we need to adjust starting position. - start_m = seq_q_blk_idx * BLOCK_M - # seqlen_q > seqlen_k, no need to process these tile for dq - delta_qk = seqlen_q - seqlen_k - if start_m + BLOCK_M < delta_qk: - return - - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - offs_m = start_m + tl.arange(0, BLOCK_M) - # Mask for loading K and V - mask_q = offs_m[:, None] < seqlen_q - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - if PADDED_HEAD: - mask_k = offs_k < BLOCK_D_MODEL - mask_q &= mask_k[None, :] - offs_q = offs_m[:, None] * stride_q_m + offs_k[None, :] * stride_q_k - offs_do = offs_m[:, None] * stride_do_m + offs_k[None, :] * stride_do_k - adj_k = batch_idx * stride_k_b + head_k_idx * stride_k_h + k_start * stride_k_n - adj_v = batch_idx * stride_v_b + head_k_idx * stride_v_h + k_start * stride_v_n - k_ptr_adj = k_ptr - v_ptr_adj = v_ptr - k_ptr_adj += adj_k - v_ptr_adj += adj_v - - # If MQA / GQA, set the K and V head offsets appropriately. - GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS - for head_q_idx in range(head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE): - # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front - # for every M-tile - end_n = start_m + BLOCK_M - delta_qk - # clamp end_n at [0, seqlen_k] - end_n = max(min(end_n, seqlen_k), 0) - - # offset input and output tensor by batch and Q/K heads - adj_q = (batch_idx * stride_q_b + - head_q_idx * stride_q_h + - q_start * stride_q_m) - adj_do = (batch_idx * stride_do_b + - head_q_idx * stride_do_h + - q_start * stride_do_m) - adj_delta = (batch_idx * stride_delta_b + - head_q_idx * stride_delta_h + - q_start * stride_delta_m) - delta_ptr_adj = delta_ptr + adj_delta - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = (philox_offset_base + - batch_idx * stride_dropout_b + - head_q_idx * stride_dropout_h) - dropout_offset = (dropout_mask + - batch_idx * stride_dropout_b + - head_q_idx * stride_dropout_h) - - q = tl.load(q_ptr + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(do_ptr + adj_do + offs_do, mask=mask_q, other=0.0) - m = tl.load(m_ptr + adj_delta + offs_m * stride_delta_m, - mask=offs_m < seqlen_q) - m = m[:, None] - - MASK_BLOCK_N: tl.constexpr = BLOCK_N // BLK_SLICE_FACTOR - # start can only be 0 at minimum - start_n = max(end_n - BLOCK_M, 0) - num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N) - - if IS_FP8: - descale_q = tl.load(descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx) - descale_k = tl.load(descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx) - descale_v = tl.load(descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx) - descale_do = tl.load(descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32) - # Compute dQ for masked (diagonal) blocks. - # NOTE: This code scans each row of QK^T backward (from right to left, - # but inside each call to _bwd_dq_inner, from left to right), but that's - # not due to anything important. I just wanted to reuse the loop - # structure for dK & dV above as much as possible. - dq = _bwd_dq_inner( - dq, - q, k_ptr_adj, v_ptr_adj, do, m, delta_ptr_adj, sm_scale, - stride_q_m, stride_q_k, stride_k_n, stride_k_k, stride_v_n, stride_v_k, - stride_dropout_m, stride_dropout_n, - stride_delta_m, - seqlen_q, seqlen_k, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M, MASK_BLOCK_N, - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, - MASK=True, - ENABLE_DROPOUT=ENABLE_DROPOUT, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - ) - end_n -= num_steps * MASK_BLOCK_N - num_steps = tl.cdiv(end_n, BLOCK_N) - start_n = max(end_n - num_steps * BLOCK_N, 0) - dq = _bwd_dq_inner( - dq, - q, k_ptr_adj, v_ptr_adj, do, m, delta_ptr_adj, sm_scale, - stride_q_m, stride_q_k, stride_k_n, stride_k_k, stride_v_n, stride_v_k, - stride_dropout_m, stride_dropout_n, - stride_delta_m, - seqlen_q, seqlen_k, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M, BLOCK_N, - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, - MASK=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - ) - # Write back dQ. - offs_dq = (batch_idx * stride_dq_b + - head_q_idx * stride_dq_h + - q_start * stride_dq_m + - offs_m[:, None] * stride_dq_m + - offs_k[None, :] * stride_dq_k) - dq *= sm_scale - tl.store(dq_ptr + offs_dq, dq, mask=mask_q) - - -@triton.jit -def _bwd_kernel_dkdvdq_noncausal( - Q, K, V, sm_scale, DO, DK, DV, DQ, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset, - descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BATCH, - NUM_K_PIDS, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - # workgroup id - wid = tl.program_id(0) # 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1 - - # Workgroups get launched first along batch dim, then in head_k dim, and then in seq k block dim - # This is in order to avoid contention for the tl.atomic_add (inside _bwd_dkdvdq_inner) that happens between workgroups that share the same batch and head_k. - bid = wid % BATCH - hkid = wid // BATCH % NUM_K_HEADS - pid = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS - - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - - dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - - start_n = pid * BLOCK_N - - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - offs_n = start_n + tl.arange(0, BLOCK_N) - mask_kv = offs_n[:, None] < seqlen_k - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - if PADDED_HEAD: - mask_kv &= offs_k < BLOCK_D_MODEL - - GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS - adj_k = (bid * stride_kb + - hkid * stride_kh + - k_start * stride_kn + - offs_n[:, None] * stride_kn + - offs_k[None, :] * stride_kk) - adj_v = (bid * stride_vb + - hkid * stride_vh + - k_start * stride_vn + - offs_n[:, None] * stride_vn + - offs_k[None, :] * stride_vk) - - k = tl.load(K + adj_k, mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) - - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - adj_q = (bid * stride_qb + hqid * stride_qh + q_start * stride_qm) - adj_dq = (bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm) - - Q_ptr = Q + adj_q - DQ_ptr = DQ + adj_dq - - adj_do = (bid * stride_dob + hqid * stride_doh + q_start * stride_dom) - DO_ptr = DO + adj_do - adj_delta = (bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam) - M_ptr = M + adj_delta - Delta_ptr = Delta + adj_delta - - #dropout - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset + bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = dropout_mask + bid * stride_dropoutb + \ - hqid * stride_dropouth - - if IS_FP8: - descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) - descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) - descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - start_m = 0 - num_steps = tl.cdiv(seqlen_q, BLOCK_M) - - dk, dv = _bwd_dkdvdq_inner( - dk, dv, - Q_ptr, k, v, DO_ptr, DQ_ptr, M_ptr, Delta_ptr, sm_scale, - stride_qm, stride_qk, - stride_dqm, stride_dqk, - stride_dom, stride_dok, - stride_dropoutm, stride_dropoutn, - stride_deltam, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - seqlen_q, seqlen_k, - start_n, start_m, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M, BLOCK_N, - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, - MASK=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - workgroup_id=pid, - ) - - adj_dkdv = (bid * stride_dkb + - hkid * stride_dkh + - k_start * stride_dkn + offs_n[:, None] * stride_dkn + - offs_k[None, :] * stride_dkk) - tl.store(DV + adj_dkdv, dv, mask=mask_kv) - dk *= sm_scale - tl.store(DK + adj_dkdv, dk, mask=mask_kv) - - - -@triton.jit -def _bwd_kernel_dkdv_noncausal( - Q, K, V, sm_scale, DO, DK, DV, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset, - descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) - - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - - dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) - - start_n = pid * BLOCK_N - - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - offs_n = start_n + tl.arange(0, BLOCK_N) - mask_kv = offs_n[:, None] < seqlen_k - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - if PADDED_HEAD: - mask_kv &= offs_k < BLOCK_D_MODEL - - GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS - adj_k = (bid * stride_kb + - hkid * stride_kh + - k_start * stride_kn + - offs_n[:, None] * stride_kn + - offs_k[None, :] * stride_kk) - adj_v = (bid * stride_vb + - hkid * stride_vh + - k_start * stride_vn + - offs_n[:, None] * stride_vn + - offs_k[None, :] * stride_vk) - - k = tl.load(K + adj_k, mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) - - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - adj_q = (bid * stride_qb + hqid * stride_qh + q_start * stride_qm) - Q_ptr = Q + adj_q - adj_do = (bid * stride_dob + hqid * stride_doh + q_start * stride_dom) - DO_ptr = DO + adj_do - adj_delta = (bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam) - M_ptr = M + adj_delta - Delta_ptr = Delta + adj_delta - - #dropout - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset + bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = dropout_mask + bid * stride_dropoutb + \ - hqid * stride_dropouth - - if IS_FP8: - descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) - descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) - descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - start_m = 0 - num_steps = tl.cdiv(seqlen_q, BLOCK_M) - dk, dv = _bwd_dkdv_inner( - dk, dv, - Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, - stride_qm, stride_qk, - stride_dom, stride_dok, - stride_dropoutm, stride_dropoutn, - stride_deltam, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - seqlen_q, seqlen_k, - start_n, start_m, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M, BLOCK_N, - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, - MASK=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - ) - - adj_dkdv = (bid * stride_dkb + - hkid * stride_dkh + - k_start * stride_dkn + offs_n[:, None] * stride_dkn + - offs_k[None, :] * stride_dkk) - tl.store(DV + adj_dkdv, dv, mask=mask_kv) - dk *= sm_scale - tl.store(DK + adj_dkdv, dk, mask=mask_kv) - - -@triton.jit -def _bwd_kernel_dq_noncausal( - Q, K, V, sm_scale, DO, DQ, - M, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset_base, - descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, - NUM_Q_HEADS: tl.constexpr, - NUM_K_HEADS: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, -): - pid = tl.program_id(0) #seqlen - bid = tl.program_id(1) #batch - hkid = tl.program_id(2) #head_k - - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - - if IS_VARLEN: - # Compute actual sequence lengths - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - start_m = pid * BLOCK_M - - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - offs_m = start_m + tl.arange(0, BLOCK_M) - - #mask for loading K and V - mask_q = offs_m[:, None] < seqlen_q - PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) - if PADDED_HEAD: - mask_k = offs_k < BLOCK_D_MODEL - mask_q &= mask_k[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk - offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn - K += adj_k - V += adj_v - - GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - delta_ptr = delta + adj_delta - - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = (philox_offset_base + - bid * stride_dropoutb + - hqid * stride_dropouth) - dropout_offset = ( - dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth) - - q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) - m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q) - m = m[:, None] - - #FP8 - if IS_FP8: - descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) - descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) - descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - start_n = 0 - end_n = seqlen_k - num_steps = tl.cdiv(seqlen_k, BLOCK_N) - dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32) - dq = _bwd_dq_inner( - dq, - q, K, V, do, m, delta_ptr, sm_scale, - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropoutm, stride_dropoutn, - stride_deltam, - seqlen_q, seqlen_k, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M, BLOCK_N, - BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, - MASK=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - ) - - adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk - dq *= sm_scale - tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) - -def attention_prefill_backward_triton_fused_atomics_impl( - do: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - softmax_lse: torch.Tensor, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - philox_seed: Optional[int] = 0, - philox_offset: Optional[int] = 0, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None, - fused: bool = False, - # seqused for FA v3 (currently ignored in this implementation) - seqused_q: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, -): - IS_FP8 = is_fp8(q) - if IS_FP8: - FP8_MAX = torch.finfo(q.dtype).max - descale_strides = (descale_q.stride(0),descale_k.stride(0),descale_v.stride(0),descale_do.stride(0) ) - - if DEBUG: - print(f"FP8 path triggered in bwd_prefill_fused_atomics.py") - else: - FP8_MAX = None - stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_do_z = None - descale_strides = (stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z) - - IS_VARLEN = True if cu_seqlens_q is not None else False - - #get strides and shape - if IS_VARLEN: - #Layout for q,k,v is thd ie [total tokens, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2] - seqlen_k, num_k_heads = max_seqlen_k, k.shape[1] - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - dq_strides = (0, dq.stride(1), dq.stride(0), dq.stride(2)) - dk_strides = (0, dk.stride(1), dk.stride(0), dk.stride(2)) - dv_strides = (0, dv.stride(1), dv.stride(0), dv.stride(2)) - do_strides = (0, do.stride(1), do.stride(0), do.stride(2)) - else: - #Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = q.shape - seqlen_k, num_k_heads = k.shape[1], k.shape[2] - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) - dq_strides = (dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3)) - dk_strides = (dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3)) - dv_strides = (dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3)) - do_strides = (do.stride(0), do.stride(2), do.stride(1), do.stride(3)) - - #BLOCK_D_MODEL, BLOCK_D_MODEL_POW2 - #padding for head_dim. Power of 2 or 16 - BLOCK_D_MODEL_POW2 = triton.next_power_of_2(head_sz) - BLOCK_D_MODEL_POW2 = max(BLOCK_D_MODEL_POW2, 16) - - #Configs - #PRE_BLOCK, BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 - #BLK_SLICE_FACTOR - NUM_WARPS, NUM_STAGES = 4, 1 - WAVES_PER_EU = 1 - PRE_BLOCK = 128 - #BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 64, 64, 16 - BLK_SLICE_FACTOR = 2 - - #init delta - delta = torch.zeros_like(softmax_lse) - if IS_VARLEN: - #[total_tokens, num_q_heads, seqlen_q] - delta_strides = (0, delta.stride(1), delta.stride(0)) - else: - #[batch, num_q_heads, seqlen_q] - delta_strides = delta.stride() - - #preprocess - #compute D(delta) = rowsum(dO*O). Note, multiplication is element-wise. - pre_grid = (triton.cdiv(max_seqlen_q, PRE_BLOCK), batch, num_q_heads) - _bwd_preprocess[pre_grid]( - o, do, - delta, - *o_strides, - *delta_strides, - descale_strides[3], - cu_seqlens_q, max_seqlen_q, - descale_do, - BLOCK_M=PRE_BLOCK, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8 - ) - - #dropout_mask - use_dropout = (dropout_p > 0.0) - if use_dropout: - dropout_mask = torch.zeros( - (batch, num_q_heads, max_seqlen_q, max_seqlen_k), - device=q.device, - dtype=torch.float32) - dropout_strides = dropout_mask.stride() - else: - dropout_mask = None - dropout_strides = (0, 0, 0, 0) - - grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, num_k_heads) - grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, num_k_heads) - - if fused: # fuses dk, dv, dq computations into one kernel by computing the dq using atomic adds between workgroups - - BLOCK_N = 128 if BLOCK_D_MODEL_POW2 < 160 else 64 # larger head sizes lead to oom - config = { - "BLOCK_M": 32, - "BLOCK_N": BLOCK_N, - "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 1, - "BLK_SLICE_FACTOR": 2, - } - - num_k_pids = (max_seqlen_k + BLOCK_N - 1) // BLOCK_N - grid_dkdvdq = (batch * num_k_heads * num_k_pids,) - - if causal: - _bwd_kernel_dkdvdq_causal[grid_dkdvdq]( - q, k, v, sm_scale, do, dk, dv, dq, - softmax_lse, delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask,dropout_p, philox_seed, philox_offset, - descale_q, descale_k, descale_v, descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BATCH=batch, - NUM_K_PIDS=num_k_pids, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - **config, - ) - else: - _bwd_kernel_dkdvdq_noncausal[grid_dkdvdq]( - q, k, v, sm_scale, do, dk, dv, dq, - softmax_lse, delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask,dropout_p, philox_seed, philox_offset, - descale_q, descale_k, descale_v, descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BATCH=batch, - NUM_K_PIDS=num_k_pids, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - **config, - ) - - return delta - - # split kernels solution: one kernel computes dk, dv and the other computes dq - - if causal: - _bwd_kernel_dkdv_causal[grid_dkdv]( - q, k, v, sm_scale, do, dk, dv, - softmax_lse, delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask,dropout_p, philox_seed, philox_offset, - descale_q, descale_k, descale_v, descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M1, - BLOCK_N=BLOCK_N1, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - _bwd_kernel_dq_causal[grid_dq]( - q, k, v, sm_scale, do, dq, - softmax_lse, delta, - *q_strides, - *k_strides, - *v_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_mask,dropout_p, philox_seed, philox_offset, - descale_q, descale_k, descale_v, descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M2, - BLOCK_N=BLOCK_N2, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - else: - _bwd_kernel_dkdv_noncausal[grid_dkdv]( - q, k, v, sm_scale, do, dk, dv, - softmax_lse, delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask,dropout_p, philox_seed, philox_offset, - descale_q, descale_k, descale_v, descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M1, - BLOCK_N=BLOCK_N1, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - - _bwd_kernel_dq_noncausal[grid_dq]( - q, k, v, sm_scale, do, dq, - softmax_lse, delta, - *q_strides, - *k_strides, - *v_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask,dropout_p, philox_seed, philox_offset, - descale_q, descale_k, descale_v, descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M2, - BLOCK_N=BLOCK_N2, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - - return delta \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py deleted file mode 100755 index 0d3b3a6fdf4..00000000000 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py +++ /dev/null @@ -1,1467 +0,0 @@ -import os -import torch -import triton # type: ignore -import triton.language as tl # type: ignore -from typing import Literal, Optional -from .utils import AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, DEBUG, compute_fp8_scaling_factors, \ - create_dropout_mask, create_dropout_mask_varlen, is_cdna, is_fp8, is_rdna, round_multiple - -# NOTE: triton fails to import tl.constexprs so create them here for the file -tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) -tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) - - -def get_autotune_configs(): - if False: - if is_cdna(): - # shared meta-parameters - NUM_STAGES = 1 - NUM_WARPS = 4 - WAVES_PER_EU = 2 - MATRIX_INSTR_NONKDIM = 16 - - preprocess_autotune_configs = [ - triton.Config({"PRE_BLOCK": 128, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config - triton.Config({"PRE_BLOCK": 64, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - triton.Config({"PRE_BLOCK": 32, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - triton.Config({"PRE_BLOCK": 16, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - ] - preprocess_autotune_keys = [ - "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", - ] - causal_autotune_configs = [ - triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config - triton.Config({'BLOCK_M1': 16, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - triton.Config({'BLOCK_M1': 16, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - ] - causal_autotune_keys = [ - "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", - ] - noncausal_autotune_configs = [ - triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config - triton.Config({'BLOCK_M1': 16, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - triton.Config({'BLOCK_M1': 16, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - ] - noncausal_autotune_keys = [ - "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", - ] - - return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) - else: - raise ValueError("Unknown Device Type") - else: - # meta-parameters - # TODO: fix num_stages later - NUM_WARPS, NUM_STAGES = 4, 1 - WAVES_PER_EU = 1 - PRE_BLOCK = 128 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - BLK_SLICE_FACTOR = 2 - - assert BLOCK_N1 == BLOCK_M2 - - # configs for the kernels - preprocess_autotune_configs = [ - triton.Config({"PRE_BLOCK": PRE_BLOCK, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - ] - preprocess_autotune_keys = [ - "max_seqlen_q", - "ACTUAL_HEAD_DIM_V", "IS_VARLEN", - ] - causal_autotune_configs = [ - triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - ] - causal_autotune_keys = [ - "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", - ] - noncausal_autotune_configs = [ - triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - ] - noncausal_autotune_keys = [ - "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", - ] - return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) - - - -(preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) = get_autotune_configs() - - -# This function computes delta given output Out and gradient DO -# Here is the I/O shape: -# Out: (batch, nhead_q, max_seqlens_q, headDim) -# DO: (batch, nhead_q, max_seqlens_q, headDim) -# Delta: (batch, nheads_q, max_seqlens_q) -@triton.autotune( - configs=preprocess_autotune_configs, - key=preprocess_autotune_keys, - use_cuda_graph=True, -) -@triton.jit -def _bwd_preprocess( - O, - DO, # noqa: E741 - Delta, - stride_ob, stride_oh, stride_om, stride_od, - stride_dob, stride_doh, stride_dom, stride_dod, - stride_delta_b, stride_delta_h, stride_delta_m, - stride_descale_do_z, - cu_seqlens_q, max_seqlen_q, - Descale_do, - PRE_BLOCK: tl.constexpr, - HEAD_DIM_V: tl.constexpr, - ACTUAL_HEAD_DIM_V: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr -): - pid_m = tl.program_id(0) - bid = tl.program_id(1) - hid = tl.program_id(2) - # Handle varlen - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - seqlen_q = q_end - q_start - else: - q_start = 0 - seqlen_q = max_seqlen_q - - # Compute offsets - offs_m = pid_m * PRE_BLOCK + tl.arange(0, PRE_BLOCK) - offs_d = tl.arange(0, HEAD_DIM_V) - # pointer offsets for O & DO - off_o = ( bid * stride_ob - + hid * stride_oh - + q_start * stride_om - + offs_m[:, None] * stride_om - + offs_d[None, :] * stride_od) # noqa: E741 - off_do = (bid * stride_dob - + hid * stride_doh - + q_start * stride_dom - + offs_m[:, None] * stride_dom - + offs_d[None, :] * stride_dod) - - # create masks - mask_m = offs_m < seqlen_q - mask_md = mask_m[:, None] - PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) - if PADDED_HEAD_V: - mask_md &= offs_d[None, :] < ACTUAL_HEAD_DIM_V - # load - o = tl.load(O + off_o, mask=mask_md, other=0.0) - do = tl.load(DO + off_do, mask=mask_md, other=0.0) - # compute and write-back to delta - if IS_FP8: - off_descale_do = bid * stride_descale_do_z + hid - descale_do = tl.load(Descale_do + off_descale_do) - - # NOTE: do is in the fp8 range and o is not in fp8 - delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) - else: - delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) - off_delta = (bid * stride_delta_b - + hid * stride_delta_h - + q_start * stride_delta_m - + offs_m * stride_delta_m) - tl.store(Delta + off_delta , delta, mask=mask_m) - - -# The main inner-loop logic for computing dK and dV. -@triton.jit -def _bwd_dkdv_inner( - dk, dv, # output - Q, k, v, DO, M, D, sm_scale, # input tensor - stride_qm, stride_qk, - stride_dom, stride_dok, - stride_dropoutm, stride_dropoutn, - stride_lse_m, stride_delta_m, - BLOCK_M: tl.constexpr, # 16 - BLOCK_N: tl.constexpr, # 128 - HEAD_DIM_QK: tl.constexpr, # - HEAD_DIM_V: tl.constexpr, # - ACTUAL_HEAD_DIM_QK: tl.constexpr, # - ACTUAL_HEAD_DIM_V: tl.constexpr, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - seqlen_q, seqlen_k, # max sequence length for q and k - # Filled in by the wrapper. - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal - ENABLE_DROPOUT: tl.constexpr, # activate dropout - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, # activate exp2 - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # if HEAD_DIM is padded - PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) - PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) - offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) - offs_k_qk = tl.arange(0, HEAD_DIM_QK) - offs_k_v = tl.arange(0, HEAD_DIM_V) - # mask to make sure not OOB of seqlen_q - mask_n = offs_n < seqlen_k - # Q and DO are (seqlen_q, head_dim) - # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM_QK, 1), transpose of q - qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k_qk[:, None] * stride_qk - # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM_V), NOT transposed - do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k_v[None, :] * stride_dok - # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. - tl.static_assert(BLOCK_N % BLOCK_M == 0) - curr_m = start_m - step_m = BLOCK_M - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) - - for blk_idx in range(num_steps): - if DEBUG_TRITON: print(f"iter {blk_idx}: curr_m = {curr_m}") # noqa: E701 - offs_m = curr_m + tl.arange(0, BLOCK_M) - # update the mask because offs_m advanced - mask_m = offs_m < seqlen_q - mask_qT = mask_m[None, :] - mask_do = mask_m[:, None] - mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) - if PADDED_HEAD_QK: - mask_qT &= offs_k_qk[:, None] < ACTUAL_HEAD_DIM_QK - if PADDED_HEAD_V: - mask_do &= offs_k_v[None, :] < ACTUAL_HEAD_DIM_V - qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) - # generate dropout mask - if ENABLE_DROPOUT: - # NOTE: dropout is transposed because it is used to mask pT - philox_offs = curr_philox_offset + \ - offs_m[None, :] * stride_dropoutm + \ - offs_n[:, None] * stride_dropoutn - if tl_DROPOUT_USE_PYTORCH: - dropout_offs = offs_m[None, :] * stride_dropoutm + \ - offs_n[:, None] * stride_dropoutn - dropout_mask = tl.load( - curr_dropout_offset + dropout_offs, - mask=mask_nm - ) - else: - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1.0 / (1 - dropout_p) - # Load m before computing qk to reduce pipeline stall. - m = tl.load(M + offs_m * stride_lse_m, mask=mask_m, other=0.0) - if IS_FP8: - qkT = (tl.dot(k, qT) * descale_q * descale_k) - else: - qkT = tl.dot(k, qT) - qkT_scaled = qkT * sm_scale - - if USE_ALIBI: - relative_pos_block = offs_n[:, None] + seqlen_q - seqlen_k - offs_m[None, :] - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - qkT_scaled += alibi_block - - if DEBUG_TRITON_DETAIL: - if start_n == 256: - print(f"qT: {qT.shape}\n", qT) - print(f"k: {k.shape}\n", k) - print(f"qkT scaled: {qkT.shape}\n", qkT_scaled) - # TODO: remove the scaling of m later when we removed re-scaling in fwd - if USE_EXP2: - pT = tl.math.exp2(qkT_scaled * RCP_LN2 - m[None, :] * RCP_LN2) - else: - pT = tl.math.exp(qkT_scaled - m[None, :]) - - # Autoregressive masking. - if MASK: - # offset offs_m with delta_qk since the causal mask starts at - # bottom right of the (seqlen_q, seqlen_k) matrix - causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] - mask = causal_mask & mask_nm - if DEBUG_TRITON_DETAIL: - if start_n == 256: - print(f"causal_mask: {causal_mask.shape}\n", causal_mask) - print(f"qkT after causal: {qkT.shape}\n", tl.where(causal_mask, qkT * sm_scale, 0.0)) - pT = tl.where(mask, pT, 0.0) - do = tl.load(do_ptrs, mask=mask_do, other=0.0) - # Compute dV. - if ENABLE_DROPOUT: - pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) - dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do)* descale_p_dropout * descale_do) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) - else: - if IS_FP8: - scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) - dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) - else: - dv += tl.dot(pT.to(do.type.element_ty), do) - - if DEBUG_TRITON_DETAIL: - if start_n == 256: - print(f"pT: {pT.shape}\n", pT) - # D (= delta) is pre-divided by ds_scale. - Di = tl.load(D + offs_m * stride_delta_m, mask=mask_m) - # Compute dP and dS. - if IS_FP8: - dpT = (tl.dot(v, tl.trans(do)) * descale_v * descale_do) - else: - dpT = tl.dot(v, tl.trans(do)) - if ENABLE_DROPOUT: - dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale - delta_i = Di[None, :] - dsT = pT * (dpT - delta_i) - if IS_FP8: - scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) - else: - dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) - # Increment pointers. - curr_m += step_m - qT_ptrs += step_m * stride_qm - do_ptrs += step_m * stride_dom - return dk, dv - -# the main inner-loop logic for computing dQ -@triton.jit -def _bwd_dq_inner( - dq, # output - q, K, V, do, m, Delta, sm_scale, # input - # shared by Q/K/V. - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropoutm, stride_dropoutn, # stride for dropout - stride_lse_m, - stride_delta_m, - seqlen_q, seqlen_k, # - BLOCK_M2: tl.constexpr, # - BLOCK_N2: tl.constexpr, # - HEAD_DIM_QK: tl.constexpr, - HEAD_DIM_V: tl.constexpr, - ACTUAL_HEAD_DIM_QK: tl.constexpr, - ACTUAL_HEAD_DIM_V: tl.constexpr, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - # Filled in by the wrapper. - start_m, start_n, end_n, num_steps, # - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # if HEAD_DIM is padded - PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) - PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M2) - offs_n = start_n + tl.arange(0, BLOCK_N2) - offs_k_qk = tl.arange(0, HEAD_DIM_QK) - offs_k_v = tl.arange(0, HEAD_DIM_V) - - # mask to make sure not OOB of seqlen_q - mask_m = offs_m < seqlen_q - - kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k_qk[:, None] * stride_kk - vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k_v[:, None] * stride_vk - # D (= delta) is pre-divided by ds_scale. - Di = tl.load(Delta + offs_m * stride_delta_m, mask=mask_m, other=0.0) - # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. - tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) - curr_n = start_n - step_n = BLOCK_N2 - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) - for blk_idx in range(num_steps): - if DEBUG_TRITON: print(f"iter {blk_idx}: curr_n = {curr_n}") # noqa: E701 - offs_n = curr_n + tl.arange(0, BLOCK_N2) - # end_n is needed because the end of causal True might not be perfectly - # aligned with the end of the block - mask_n = offs_n < end_n - if DEBUG_TRITON_DETAIL: print(f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}") # noqa: E701 - if DEBUG_TRITON_DETAIL: print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 - mask_kT = mask_n[None, :] - mask_vT = mask_n[None, :] - mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) - if PADDED_HEAD_QK: - mask_kT &= offs_k_qk[:, None] < ACTUAL_HEAD_DIM_QK - if PADDED_HEAD_V: - mask_vT &= offs_k_v[:, None] < ACTUAL_HEAD_DIM_V - - kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) - vT = tl.load(vT_ptrs, mask=mask_vT, other=0.0) - - if ENABLE_DROPOUT: - # NOTE: dropout is transposed because it is used to mask pT - philox_offs = curr_philox_offset + \ - offs_m[:, None] * stride_dropoutm + \ - offs_n[None, :] * stride_dropoutn - if tl_DROPOUT_USE_PYTORCH: - dropout_offs = offs_m[:, None] * stride_dropoutm + \ - offs_n[None, :] * stride_dropoutn - dropout_mask = tl.load( - curr_dropout_offset + dropout_offs, - mask=mask_mn) - else: - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1 / (1 - dropout_p) - - if IS_FP8: - qk = (tl.dot(q, kT) * descale_q * descale_k) - else: - qk = tl.dot(q, kT) - qk_scaled = qk * sm_scale - - if USE_ALIBI: - relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - qk_scaled += alibi_block - - if DEBUG_TRITON_DETAIL: print(f"qk scaled: {qk.shape}\n", qk_scaled) # noqa: E701 - if USE_EXP2: - p = tl.math.exp2(qk_scaled * RCP_LN2 - m * RCP_LN2) - else: - p = tl.math.exp(qk_scaled - m) - - # Autoregressive masking. - if MASK: - causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] - mask = causal_mask & mask_mn - p = tl.where(mask, p, 0.0) - # Compute dP and dS. - if IS_FP8: - dp = (tl.dot(do, vT) * descale_do * descale_v) - else: - dp = tl.dot(do, vT) - if ENABLE_DROPOUT: - dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale - delta_i = Di[:, None] - ds = p * (dp -delta_i) - # Compute dQ. - # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. - if IS_FP8: - scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) - dq += (tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) * descale_ds * descale_k) - else: - dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) - # Increment pointers. - curr_n += step_n - kT_ptrs += step_n * stride_kn - vT_ptrs += step_n * stride_vn - return dq - -@triton.autotune( - configs=causal_autotune_configs, - key=causal_autotune_keys, - use_cuda_graph=True, -) -@triton.jit -def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), batch) - Q, K, V, sm_scale, DO, DQ, DK, DV, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qd, - stride_kb, stride_kh, stride_kn, stride_kd, - stride_vb, stride_vh, stride_vn, stride_vd, - stride_dqb, stride_dqh, stride_dqm, stride_dqd, - stride_dkb, stride_dkh, stride_dkn, stride_dkd, - stride_dvb, stride_dvh, stride_dvn, stride_dvd, - stride_lse_b, stride_lse_h, stride_lse_m, - stride_delta_b, stride_delta_h, stride_delta_m, - stride_dob, stride_doh, stride_dom, stride_dod, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - HQ, HK, - cu_seqlens_q, cu_seqlens_k, - seqused_q, seqused_k, # Add seqused parameters - max_seqlen_q, max_seqlen_k, - Dropout_mask, dropout_p, philox_seed, philox_offset_base, - Alibi_slopes, - Descale_q, Descale_k, Descale_v, Descale_do, - BLOCK_M1: tl.constexpr, - BLOCK_N1: tl.constexpr, - BLOCK_M2: tl.constexpr, - BLOCK_N2: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM_QK: tl.constexpr, - HEAD_DIM_V: tl.constexpr, - ACTUAL_HEAD_DIM_QK: tl.constexpr, - ACTUAL_HEAD_DIM_V: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - FP8_OUTPUT: tl.constexpr, - USE_SEQUSED: tl.constexpr, # Add flag for seqused - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # program ids - hkid = tl.program_id(0) - pid = tl.program_id(1) - bid = tl.program_id(2) - if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 - # figure out varlen start and end - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - # Compute actual sequence lengths - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - - # If seqused is provided, use it to limit the actual sequence length - if USE_SEQUSED: - actual_seqlen_q = tl.load(seqused_q + bid) if seqused_q is not None else q_end - q_start - seqlen_q = tl.minimum(actual_seqlen_q, q_end - q_start) - actual_seqlen_k = tl.load(seqused_k + bid) if seqused_k is not None else k_end - k_start - seqlen_k = tl.minimum(actual_seqlen_k, k_end - k_start) - else: - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - delta_qk = seqlen_q - seqlen_k - if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") # noqa: E701 - PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) - PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) - offs_d_qk = tl.arange(0, HEAD_DIM_QK) - offs_d_v = tl.arange(0, HEAD_DIM_V) - GROUP_SIZE: tl.constexpr = HQ // HK - - # align the delta_qk - start_n = pid * BLOCK_N1 - if start_n < seqlen_k: - # This section does dk and dv - dk = tl.zeros([BLOCK_N1, HEAD_DIM_QK], dtype=tl.float32) - dv = tl.zeros([BLOCK_N1, HEAD_DIM_V], dtype=tl.float32) - - # q > k: diretcly skip all the way until the start of causal block - start_delta_q_gt_k = delta_qk - # q < k: some blocks will have no Masked block, other needs to re-calc - # starting position - # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the - # masked op - num_blocks_skip = -delta_qk // BLOCK_N1 - delta_aligned = (num_blocks_skip + 1) * BLOCK_N1 + delta_qk - start_delta_q_lt_k = delta_aligned // BLOCK_M1 * BLOCK_M1 - if delta_qk >= 0: - start_delta = delta_qk - if DEBUG_TRITON: print(f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}") # noqa: E701 - else: - start_delta = start_delta_q_lt_k - if DEBUG_TRITON: print(f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}") # noqa: E701 - - offs_n = start_n + tl.arange(0, BLOCK_N1) - # Mask for loading K and V - mask_k = offs_n[:, None] < seqlen_k - mask_v = offs_n[:, None] < seqlen_k - if PADDED_HEAD_QK: - mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK - mask_k &= mask_d_qk[None, :] - if PADDED_HEAD_V: - mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V - mask_v &= mask_d_v[None, :] - - # K/V tensors not changed for the group - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d_qk[None, :] * stride_kd - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d_v[None, :] * stride_vd - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k, mask=mask_k, other=0.0) - v = tl.load(V + adj_v, mask=mask_v, other=0.0) - # If MQA / GQA, set the K and V head offsets appropriately. - # hqid = hkid - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - if delta_qk >= 0: - start_m = start_n + start_delta - len_m = BLOCK_N1 - else: - start_m = max(start_n + delta_qk, 0) - start_m = start_m // BLOCK_M1 * BLOCK_M1 - # because we might shift the masked blocks up, we are deeper into - # the masked out region, so we would potentially increase the total - # steps with masked operation to get out of it - residue_m = max(start_n + delta_qk - start_m, 0) - len_m = BLOCK_N1 + residue_m - if DEBUG_TRITON: print(f"residue_m = {residue_m}") # noqa: E701 - - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - Q_ptr = Q + adj_q - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - DO_ptr = DO + adj_do - adj_delta = bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m - Delta_ptr = Delta + adj_delta - adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m - M_ptr = M + adj_m - - if USE_ALIBI: - alibi_offset = bid * stride_az + hqid * stride_ah - alibi_slope = tl.load(Alibi_slopes + alibi_offset) - else: - alibi_slope = None - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = Dropout_mask + bid * stride_dropoutb + \ - hqid * stride_dropouth - - if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) - descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) - descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR - # bound the masked operation to q len so it does not have to wast cycles - len_m = min(len_m, seqlen_q) - num_steps = tl.cdiv(len_m, MASK_BLOCK_M1) - # when q < k, we may skip the initial masked op - if pid < num_blocks_skip: - num_steps = 0 - - # if start_m is negative, the current N-tile has no block on the - # diagonal of causal mask, so everything have no causal mask - if DEBUG_TRITON: print(f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}") # noqa: E701 - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qd, # strides for q - stride_dom, stride_dod, # strides for o - stride_dropoutm, stride_dropoutn, # strides for dropout - stride_lse_m, stride_delta_m, - MASK_BLOCK_M1, BLOCK_N1, # block dim - HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, - MASK=True, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - start_m += num_steps * MASK_BLOCK_M1 - num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M1) - end_m = start_m + num_steps * BLOCK_M1 - - if DEBUG_TRITON: print(f"start_m after Masked step: {start_m}; num_steps: {num_steps}") # noqa: E701 - if DEBUG_TRITON: print(f"unMasked: start_n: {start_n}, start_m: {start_m}, end_m: {end_m}, num_steps: {num_steps}") # noqa: E701 - if DEBUG_TRITON: print("unMasked") # noqa: E701 - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qd, # strides for q - stride_dom, stride_dod, # strides for o - stride_dropoutm, stride_dropoutn, # strides for dropout - stride_lse_m, stride_delta_m, - BLOCK_M1, BLOCK_N1, # block dim - HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, - MASK=False, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - # end of GQA/MQA of dkdv - # Write back dV - adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn - offs_dv = offs_n[:, None] * stride_dvn + offs_d_v[None, :] * stride_dvd - tl.store(DV + adj_dv + offs_dv, dv, mask=mask_v) - # write back dk - adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn - offs_dk = offs_n[:, None] * stride_dkn + offs_d_qk[None, :] * stride_dkd - dk *= sm_scale - tl.store(DK + adj_dk + offs_dk, dk, mask=mask_k) - - # This part does dq - start_m = pid * BLOCK_M2 - if start_m < seqlen_q: - # seqlen_q > seqlen_k, no need to process these tile for dq - if DEBUG_TRITON: print(f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2}") # noqa: E701 - if start_m + BLOCK_M2 < delta_qk: - if DEBUG_TRITON: print(f"start_m + BLOCK_M2 = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2} < delta_qk of {delta_qk}") # noqa: E701 - return - - offs_m = start_m + tl.arange(0, BLOCK_M2) - # Mask for loading K and V - mask_q = offs_m[:, None] < seqlen_q - mask_do = offs_m[:, None] < seqlen_q - if PADDED_HEAD_QK: - mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK - mask_q &= mask_d_qk[None, :] - if PADDED_HEAD_V: - mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V - mask_do &= mask_d_v[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qd - offs_do = offs_m[:, None] * stride_dom + offs_d_v[None, :] * stride_dod - # NOTE: don't assume that the strides for k and v are the same! - K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn - V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn - - # If MQA / GQA, set the K and V head offsets appropriately. - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front - # for every M-tile - end_n = start_m + BLOCK_M2 - delta_qk - # clamp end_n at [0, seqlen_k] - end_n = max(min(end_n, seqlen_k), 0) - if DEBUG_TRITON: print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - adj_delta = \ - bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m - Delta_ptr = Delta + adj_delta - adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m - M_ptr = M + adj_m - - if USE_ALIBI: - alibi_offset = bid * stride_az + hqid * stride_ah - alibi_slope = tl.load(Alibi_slopes + alibi_offset) - else: - alibi_slope = None - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + \ - bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = \ - Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth - q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_do, other=0.0) - m = tl.load(M + adj_m + offs_m * stride_lse_m, - mask=offs_m < seqlen_q) - m = m[:, None] - - MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR - # start can only be 0 at minimum - start_n = max(end_n - BLOCK_M2, 0) - num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N2) - - if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) - descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) - descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) - dq = _bwd_dq_inner( - dq, - q, K, V, do, m, Delta_ptr, sm_scale, - stride_qm, stride_qd, stride_kn, stride_kd, stride_vn, stride_vd, - stride_dropoutm, stride_dropoutn, - stride_lse_m, - stride_delta_m, - seqlen_q, seqlen_k, - BLOCK_M2, MASK_BLOCK_N2, - HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - MASK=True, # - ENABLE_DROPOUT=ENABLE_DROPOUT, - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - end_n -= num_steps * MASK_BLOCK_N2 - num_steps = tl.cdiv(end_n, BLOCK_N2) - start_n = max(end_n - num_steps * BLOCK_N2, 0) - if DEBUG_TRITON: print(f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 - dq = _bwd_dq_inner( - dq, - q, K, V, do, m, Delta_ptr, sm_scale, - stride_qm, stride_qd, stride_kn, stride_kd, stride_vn, stride_vd, - stride_dropoutm, stride_dropoutn, - stride_lse_m, - stride_delta_m, - seqlen_q, seqlen_k, - BLOCK_M2, BLOCK_N2, - HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - MASK=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - # Write back dQ. - adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_d_qk[None, :] * stride_dqd - dq *= sm_scale - tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) - # end of GQA/MQA of dq - -@triton.autotune( - configs=noncausal_autotune_configs, - key=noncausal_autotune_keys, - use_cuda_graph=True, -) -@triton.jit -def bwd_kernel_noncausal( - Q, K, V, sm_scale, DO, DQ, DK, DV, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qd, - stride_kb, stride_kh, stride_kn, stride_kd, - stride_vb, stride_vh, stride_vn, stride_vd, - stride_dqb, stride_dqh, stride_dqm, stride_dqd, - stride_dkb, stride_dkh, stride_dkn, stride_dkd, - stride_dvb, stride_dvh, stride_dvn, stride_dvd, - stride_lse_b, stride_lse_h, stride_lse_m, - stride_delta_b, stride_delta_h, stride_delta_m, - stride_dob, stride_doh, stride_dom, stride_dod, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - HQ, HK, - cu_seqlens_q, cu_seqlens_k, - seqused_q, seqused_k, # Add seqused parameters - max_seqlen_q, max_seqlen_k, - Dropout_mask, dropout_p, philox_seed, philox_offset_base, - Alibi_slopes, - Descale_q, Descale_k, Descale_v, Descale_do, - BLOCK_M1: tl.constexpr, # 32 - BLOCK_N1: tl.constexpr, # 128 - BLOCK_M2: tl.constexpr, # 128 - BLOCK_N2: tl.constexpr, # 32 - BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM_QK: tl.constexpr, - HEAD_DIM_V: tl.constexpr, - ACTUAL_HEAD_DIM_QK: tl.constexpr, - ACTUAL_HEAD_DIM_V: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - FP8_OUTPUT: tl.constexpr, - USE_SEQUSED: tl.constexpr, # Add flag for seqused - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # program ids - hkid = tl.program_id(0) - pid = tl.program_id(1) - bid = tl.program_id(2) - if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 - # figure out varlen start and end - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - # Compute actual sequence lengths - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - - # If seqused is provided, use it to limit the actual sequence length - if USE_SEQUSED: - actual_seqlen_q = tl.load(seqused_q + bid) if seqused_q is not None else q_end - q_start - seqlen_q = tl.minimum(actual_seqlen_q, q_end - q_start) - actual_seqlen_k = tl.load(seqused_k + bid) if seqused_k is not None else k_end - k_start - seqlen_k = tl.minimum(actual_seqlen_k, k_end - k_start) - else: - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) - PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) - offs_d_qk = tl.arange(0, HEAD_DIM_QK) - offs_d_v = tl.arange(0, HEAD_DIM_V) - GROUP_SIZE: tl.constexpr = HQ // HK - - start_n = pid * BLOCK_N1 - if start_n < seqlen_k: - dk = tl.zeros([BLOCK_N1, HEAD_DIM_QK], dtype=tl.float32) - dv = tl.zeros([BLOCK_N1, HEAD_DIM_V], dtype=tl.float32) - - offs_n = start_n + tl.arange(0, BLOCK_N1) - # Mask for loading K and V - mask_k = offs_n[:, None] < seqlen_k - mask_v = offs_n[:, None] < seqlen_k - if PADDED_HEAD_QK: - mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK - mask_k &= mask_d_qk[None, :] - if PADDED_HEAD_V: - mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V - mask_v &= mask_d_v[None, :] - # NOTE: don't assume that the strides for k and v are the same! - # K/V tensors not changed for the group - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d_qk[None, :] * stride_kd - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d_v[None, :] * stride_vd - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k, mask=mask_k, other=0.0) - v = tl.load(V + adj_v, mask=mask_v, other=0.0) - # If MQA / GQA, set the K and V head offsets appropriately. - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - Q_ptr = Q + adj_q - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - DO_ptr = DO + adj_do - adj_delta = bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m - Delta_ptr = Delta + adj_delta - adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m - M_ptr = M + adj_m - - if USE_ALIBI: - alibi_offset = bid * stride_az + hqid * stride_ah - alibi_slope = tl.load(Alibi_slopes + alibi_offset) - else: - alibi_slope = None - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = Dropout_mask + bid * stride_dropoutb + \ - hqid * stride_dropouth - - if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) - descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) - descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - # because there is no causal, we always start from the beginning - start_m = 0 - num_steps = tl.cdiv(seqlen_q, BLOCK_M1) - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qd, # strides for q - stride_dom, stride_dod, # strides for o - stride_dropoutm, stride_dropoutn, # strides for dropout - stride_lse_m, - stride_delta_m, - BLOCK_M1, BLOCK_N1, # block dim - HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - alibi_slope, - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK=False, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - # Write back dV - adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn - offs_dv = offs_n[:, None] * stride_dvn + offs_d_v[None, :] * stride_dvd - tl.store(DV + adj_dv + offs_dv, dv, mask=mask_v) - # write back dk - adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn - offs_dk = offs_n[:, None] * stride_dkn + offs_d_qk[None, :] * stride_dkd - dk *= sm_scale - tl.store(DK + adj_dk + offs_dk, dk, mask=mask_k) - - # THIS PART DOES DQ - start_m = pid * BLOCK_M2 - if start_m < seqlen_q: - offs_m = start_m + tl.arange(0, BLOCK_M2) - # Mask for loading K and V - mask_q = offs_m[:, None] < seqlen_q - mask_do = offs_m[:, None] < seqlen_q - if PADDED_HEAD_QK: - mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK - mask_q &= mask_d_qk[None, :] - if PADDED_HEAD_V: - mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V - mask_do &= mask_d_v[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qd - offs_do = offs_m[:, None] * stride_dom + offs_d_v[None, :] * stride_dod - K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn - V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn - # If MQA / GQA, set the K and V head offsets appropriately. - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - adj_delta = \ - bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m - Delta_ptr = Delta + adj_delta - adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m - M_ptr = M + adj_m - - if USE_ALIBI: - alibi_offset = bid * stride_az + hqid * stride_ah - alibi_slope = tl.load(Alibi_slopes + alibi_offset) - else: - alibi_slope = None - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + \ - bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = \ - Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth - - q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_do, other=0.0) - m = tl.load(M + adj_m + offs_m * stride_lse_m, - mask=offs_m < seqlen_q) - m = m[:, None] - - if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) - descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) - descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - # start can only be 0 at minimum - start_n = 0 - end_n = seqlen_k - num_steps = tl.cdiv(seqlen_k, BLOCK_N2) - - dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) - dq = _bwd_dq_inner( - dq, - q, K, V, do, m, Delta_ptr, sm_scale, - stride_qm, stride_qd, stride_kn, stride_kd, stride_vn, stride_vd, - stride_dropoutm, stride_dropoutn, - stride_lse_m, - stride_delta_m, - seqlen_q, seqlen_k, - BLOCK_M2, BLOCK_N2, - HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - MASK=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - # Write back dQ. - adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_d_qk[None, :] * stride_dqd - dq *= sm_scale - tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) - -def is_contiguous(x, name): - if x.is_contiguous(): - return x - else: - print(f"{name} is not contiguous") - return x.contiguous() - -OLD_LSE: bool = False -DEBUG_TRITON: bool = False -DEBUG_TRITON_DETAIL: bool = False - -def attention_prefill_backward_triton_split_fused_no_atomics_impl( - do: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - softmax_lse: torch.Tensor, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - layout: Literal["bshd", "bhsd", "thd"], - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], - dropout_p: float, - philox_seed: Optional[int], - philox_offset: Optional[int], - use_exp2: bool, - # fp8 - descale_q: Optional[torch.Tensor], - descale_k: Optional[torch.Tensor], - descale_v: Optional[torch.Tensor], - descale_o: Optional[torch.Tensor], - descale_do: Optional[torch.Tensor], - descale_dq: Optional[torch.Tensor], - descale_dk: Optional[torch.Tensor], - descale_dv: Optional[torch.Tensor], - # seqused for FA v3 - seqused_q: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, -): - # get params, strides and shape - IS_VARLEN = layout == "thd" - use_dropout = (dropout_p > 0.0) - - # common assertions - assert 0.0 <= dropout_p <= 1.0, f"dropout_p must be between 0 and 1, got {dropout_p}" - assert q.device == k.device == v.device == o.device == do.device == softmax_lse.device, \ - f"All tensors must be on the same device. Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}, do={do.device}, softmax_lse={softmax_lse.device}" - assert q.dtype == k.dtype == v.dtype == do.dtype, "q, k, v, do must have the same dtype" - current_device = torch.cuda.current_device() - assert q.is_cuda and q.device.index == current_device, f"Device mismatch: Kernel will launch on cuda:{current_device}, but tensors are on {q.device}" - - # get shapes and strides - if IS_VARLEN: - # shape - total_seqlen_q, nheads_q, head_size_q = q.shape - total_seqlen_k, nheads_k, head_size_k = k.shape - total_seqlen_v, nheads_v, head_size_v = v.shape - nheads_lse, total_seqlen_lse = softmax_lse.shape - - # assert shapes - assert total_seqlen_lse == total_seqlen_q, f"softmax_lse seqlen {total_seqlen_lse} != q seqlen {total_seqlen_q}" - assert cu_seqlens_q is not None, "cu_seqlens_q must be provided for varlen layout" - assert cu_seqlens_k is not None, "cu_seqlens_k must be provided for varlen layout" - assert max_seqlen_q is not None, "max_seqlen_q must be provided for varlen layout" - assert max_seqlen_k is not None, "max_seqlen_k must be provided for varlen layout" - - # assert head dimensions - assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" - assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" - assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" - assert nheads_lse == nheads_q, f"softmax_lse heads {nheads_lse} != q heads {nheads_q}" - - # assert output shapes - assert o.shape == (total_seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_v)}" - assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" - assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" - assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}" - assert dv.shape == v.shape, f"dv shape {dv.shape} != v shape {v.shape}" - - # assert cu_seqlens - assert cu_seqlens_q.dtype == torch.int32, f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}" - assert cu_seqlens_k.dtype == torch.int32, f"cu_seqlens_k must be int32, got {cu_seqlens_k.dtype}" - assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0" - assert cu_seqlens_k[0] == 0, "cu_seqlens_k must start with 0" - assert cu_seqlens_q[-1] == total_seqlen_q, f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} != total_seqlen_q {total_seqlen_q}" - assert cu_seqlens_k[-1] == total_seqlen_k, f"cu_seqlens_k[-1] {cu_seqlens_k[-1]} != total_seqlen_k {total_seqlen_k}" - - # set vars - batch = len(cu_seqlens_q) - 1 - head_size_qk = head_size_q - - # strides - stride_qb, stride_qm, stride_qh, stride_qd = 0, q.stride(0), q.stride(1), q.stride(2) - stride_kb, stride_kn, stride_kh, stride_kd = 0, k.stride(0), k.stride(1), k.stride(2) - stride_vb, stride_vn, stride_vh, stride_vd = 0, v.stride(0), v.stride(1), v.stride(2) - stride_ob, stride_om, stride_oh, stride_od = 0, o.stride(0), o.stride(1), o.stride(2) - stride_dqb, stride_dqm, stride_dqh, stride_dqd = 0, dq.stride(0), dq.stride(1), dq.stride(2) - stride_dkb, stride_dkn, stride_dkh, stride_dkd = 0, dk.stride(0), dk.stride(1), dk.stride(2) - stride_dvb, stride_dvn, stride_dvh, stride_dvd = 0, dv.stride(0), dv.stride(1), dv.stride(2) - stride_dob, stride_dom, stride_doh, stride_dod = 0, do.stride(0), do.stride(1), do.stride(2) - stride_lse_b, stride_lse_h, stride_lse_m = (0, softmax_lse.stride(0), softmax_lse.stride(1)) - else: - # shapes - batch_q, seqlen_q, nheads_q, head_size_q = q.shape - batch_k, seqlen_k, nheads_k, head_size_k = k.shape - batch_v, seqlen_v, nheads_v, head_size_v = v.shape - batch_lse, nheads_lse, seqlen_lse = softmax_lse.shape - - # assert batch dimensions - assert batch_q == batch_k == batch_v, f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" - - # assert head dimensions - assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" - assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" - assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" - - # assert sequence lengths - assert seqlen_k == seqlen_v, f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" - - # assert output shapes - assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected" - assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" - assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" - assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}" - assert dv.shape == v.shape, f"dv shape {dv.shape} != v shape {v.shape}" - - # assert softmax_lse shape - assert softmax_lse.shape == (batch_q, nheads_q, seqlen_q), f"softmax_lse shape {softmax_lse.shape} != expected" - - # set vars - batch = batch_q - head_size_qk = head_size_q - max_seqlen_q = seqlen_q - max_seqlen_k = seqlen_k - - # strides - stride_qb, stride_qm, stride_qh, stride_qd = q.stride() - stride_kb, stride_kn, stride_kh, stride_kd = k.stride() - stride_vb, stride_vn, stride_vh, stride_vd = v.stride() - stride_ob, stride_om, stride_oh, stride_od = o.stride() - stride_dqb, stride_dqm, stride_dqh, stride_dqd = dq.stride() - stride_dkb, stride_dkn, stride_dkh, stride_dkd = dk.stride() - stride_dvb, stride_dvn, stride_dvh, stride_dvd = dv.stride() - stride_dob, stride_dom, stride_doh, stride_dod = do.stride() - stride_lse_b, stride_lse_h, stride_lse_m = softmax_lse.stride() - - # fp8 setup - moved after all assertions - IS_FP8 = is_fp8(q) - if IS_FP8: - FP8_MAX = torch.finfo(q.dtype).max - # we already asserted that do, q, k, v all have the same dtype, so no need to check each one - if is_fp8(o): - FP8_OUTPUT = True - assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o." - assert descale_dq is not None, f"descale_dq is None. In fp8, you need to pass a tensor for descale_dq along with a tensor dq." - assert descale_dk is not None, f"descale_dk is None. In fp8, you need to pass a tensor for descale_dk along with a tensor dk." - assert descale_dv is not None, f"descale_dv is None. In fp8, you need to pass a tensor for descale_dv along with a tensor dv." - else: - FP8_OUTPUT = False - - stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None - stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None - stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None - stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None - stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None - - if DEBUG: - print(f"FP8 path triggered in bwd_prefill_fused_no_atomics.py (FP8_OUTPUT={FP8_OUTPUT})") - else: - FP8_MAX = None - FP8_OUTPUT = False - stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = stride_descale_do_z = None - - # alibi setup - use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) - - # get closest power of 2 over or equal to 32. - padded_d_model_qk = 1 << (head_size_qk - 1).bit_length() - padded_d_model_qk = max(padded_d_model_qk, 32) - padded_d_model_v = 1 << (head_size_v - 1).bit_length() - padded_d_model_v = max(padded_d_model_v, 32) - HEAD_DIM_QK = padded_d_model_qk - HEAD_DIM_V = padded_d_model_v - ACTUAL_HEAD_DIM_QK = head_size_qk - ACTUAL_HEAD_DIM_V = head_size_v - - # init delta - if OLD_LSE: - delta = torch.empty_like(softmax_lse) - if IS_VARLEN: - stride_delta_b, stride_delta_h, stride_delta_m = 0, delta.stride(0), delta.stride(1) - else: - stride_delta_b, stride_delta_h, stride_delta_m = delta.stride() - else: - if IS_VARLEN: - # interface expects the varlen sequence dims to rounded like this. Not sure why. - total_q, num_heads, _ = q.shape - total_q_rounded = total_q + 128 * batch - delta_padded = torch.zeros((nheads_q, total_q_rounded), device=q.device, dtype=torch.float32) - delta = delta_padded[:, :total_q] - stride_delta_b, stride_delta_h, stride_delta_m = 0, delta.stride(0), delta.stride(1) - else: - # the interface expects the sequence dimension to be rounded to 128 - max_seqlen_q_rounded = round_multiple(max_seqlen_q, 128) - delta_padded = torch.zeros((batch, nheads_q, max_seqlen_q_rounded), - device=q.device, dtype=torch.float32) - delta = delta_padded[:, :, :max_seqlen_q] - stride_delta_b, stride_delta_h, stride_delta_m = delta.stride() - - pre_grid = lambda META: (triton.cdiv(max_seqlen_q, META['PRE_BLOCK']), batch, nheads_q) - _bwd_preprocess[pre_grid]( - o, do, - delta, - stride_ob, stride_oh, stride_om, stride_od, - stride_dob, stride_doh, stride_dom, stride_dod, - stride_delta_b, stride_delta_h, stride_delta_m, - stride_descale_do_z, - cu_seqlens_q, max_seqlen_q, - descale_do, - HEAD_DIM_V=HEAD_DIM_V, - ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8 - ) - - if False: - print("delta:", delta, delta.shape) - - # dropout mask tensor for debugging. We dump the dropout mask created in - # the kernel for testing - dropout_mask = None - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ - (0, 0 , 0 , 0) - if use_dropout: - dropout_mask = torch.zeros( - (batch, nheads_q, max_seqlen_q, max_seqlen_k), - device=q.device, - dtype=torch.float32 - ) - - if DROPOUT_USE_PYTORCH: - if not IS_VARLEN: - dropout_mask = create_dropout_mask( - dropout_p, - (batch, nheads_q, max_seqlen_q, max_seqlen_k), - seed = philox_seed - ) - else: - dropout_mask = create_dropout_mask_varlen( - dropout_p, batch, nheads_q, - cu_seqlens_q, cu_seqlens_k, philox_seed - ) - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ - dropout_mask.stride() - - seqlen = max(max_seqlen_q, max_seqlen_k) - grid = lambda META: (nheads_k, (seqlen + META['BLOCK_N1'] - 1) // META['BLOCK_N1'], batch, ) - if causal: - if DEBUG_TRITON: print(f"bwd_kernel: grid = {grid}" ) # noqa: E701 - bwd_kernel_causal[grid]( - q, k, v, sm_scale, do, dq, dk, dv, - softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qd, - stride_kb, stride_kh, stride_kn, stride_kd, - stride_vb, stride_vh, stride_vn, stride_vd, - stride_dqb, stride_dqh, stride_dqm, stride_dqd, - stride_dkb, stride_dkh, stride_dkn, stride_dkd, - stride_dvb, stride_dvh, stride_dvn, stride_dvd, - stride_lse_b, stride_lse_h, stride_lse_m, - stride_delta_b, stride_delta_h, stride_delta_m, - stride_dob, stride_doh, stride_dom, stride_dod, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - nheads_q, nheads_k, - cu_seqlens_q, cu_seqlens_k, - seqused_q, seqused_k, # Pass seqused tensors - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset, - alibi_slopes, - descale_q, descale_k, descale_v, descale_do, - HEAD_DIM_QK=HEAD_DIM_QK, - HEAD_DIM_V=HEAD_DIM_V, - ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, - ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_ALIBI=use_alibi, - USE_EXP2=use_exp2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - FP8_OUTPUT=FP8_OUTPUT, - USE_SEQUSED=(seqused_q is not None or seqused_k is not None), # Add flag for seqused - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - else: - bwd_kernel_noncausal[grid]( - q, k, v, sm_scale, do, dq, dk, dv, - softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qd, - stride_kb, stride_kh, stride_kn, stride_kd, - stride_vb, stride_vh, stride_vn, stride_vd, - stride_dqb, stride_dqh, stride_dqm, stride_dqd, - stride_dkb, stride_dkh, stride_dkn, stride_dkd, - stride_dvb, stride_dvh, stride_dvn, stride_dvd, - stride_lse_b, stride_lse_h, stride_lse_m, - stride_delta_b, stride_delta_h, stride_delta_m, - stride_dob, stride_doh, stride_dom, stride_dod, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - nheads_q, nheads_k, - cu_seqlens_q, cu_seqlens_k, - seqused_q, seqused_k, # Pass seqused tensors - max_seqlen_q, max_seqlen_k, - dropout_mask, dropout_p, philox_seed, philox_offset, - alibi_slopes, - descale_q, descale_k, descale_v, descale_do, - HEAD_DIM_QK=HEAD_DIM_QK, - HEAD_DIM_V=HEAD_DIM_V, - ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, - ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_ALIBI=use_alibi, - USE_EXP2=use_exp2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - FP8_OUTPUT=FP8_OUTPUT, - USE_SEQUSED=(seqused_q is not None or seqused_k is not None), # Add flag for seqused - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - if OLD_LSE: - return delta - else: - return delta_padded diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py deleted file mode 100755 index cfff39bf8bf..00000000000 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py +++ /dev/null @@ -1,1360 +0,0 @@ -import torch -import triton # type: ignore -import triton.language as tl # type: ignore -from typing import Literal, Optional -from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, DEBUG, compute_fp8_scaling_factors, get_shapes_from_layout, \ - get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_fp8 - -# NOTE: triton fails to import tl.constexprs so create them here for the file -tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) -tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) - -# This function computes delta given output Out and gradient DO -# Here is the I/O shape: -# Out: (batch, nhead_q, max_seqlens_q, headDim) -# DO: (batch, nhead_q, max_seqlens_q, headDim) -# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at -# fwd_prefill.py line 607 -@triton.jit -def _bwd_preprocess( - O, DO, # noqa: E741 - Delta, - stride_ob, stride_oh, stride_om, stride_ok, - stride_deltab, stride_deltah, stride_deltam, - stride_descale_do_z, - cu_seqlens_q, max_seqlen_q, - Descale_do, - BLOCK_M: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr -): - pid_m = tl.program_id(0) - bid = tl.program_id(1) - hid = tl.program_id(2) - # Handle varlen - q_start = 0 - seqlen_q = max_seqlen_q - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - seqlen_q = q_end - q_start - else: - q_start = 0 - seqlen_q = max_seqlen_q - - # Compute offsets - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_k = tl.arange(0, HEAD_DIM) - # Offset O/DO by batch, head and q_start - O += bid * stride_ob + hid * stride_oh + q_start * stride_om # noqa: E741 - DO += bid * stride_ob + hid * stride_oh + q_start * stride_om - # create masks - mask_m = offs_m < seqlen_q - mask_md = mask_m[:, None] - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - if PADDED_HEAD: - mask_md &= offs_k[None, :] < ACTUAL_HEAD_DIM - # compute pointers - offs_do = offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok - out_ptrs = O + offs_do - do_ptrs = DO + offs_do - # load - o = tl.load(out_ptrs, mask=mask_md, other=0.0) - do = tl.load(do_ptrs, mask=mask_md, other=0.0) - # compute and write-back to delta - if IS_FP8: - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hid) - - # NOTE: do is in the fp8 range and o is not in fp8 - delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) - else: - delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) - delta_offset = Delta + bid * stride_deltab + hid * stride_deltah + q_start * stride_deltam - tl.store(delta_offset + offs_m * stride_deltam, delta, mask=mask_m) - - -# The main inner-loop logic for computing dK and dV. -@triton.jit -def _bwd_dkdv_inner( - dk, dv, # output - Q, k, v, DO, M, D, sm_scale, # input tensor - stride_qm, stride_qk, - stride_dom, stride_dok, - stride_dropoutm, stride_dropoutn, - stride_deltam, - BLOCK_M: tl.constexpr, # 16 - BLOCK_N: tl.constexpr, # 128 - HEAD_DIM: tl.constexpr, # - ACTUAL_HEAD_DIM: tl.constexpr, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - seqlen_q, seqlen_k, # max sequence length for q and k - # Filled in by the wrapper. - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal - ENABLE_DROPOUT: tl.constexpr, # activate dropout - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, # activate exp2 - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # if HEAD_DIM is padded - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) - offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) - offs_k = tl.arange(0, HEAD_DIM) - # mask to make sure not OOB of seqlen_q - mask_n = offs_n < seqlen_k - # Q and DO are (seqlen_q, head_dim) - # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q - qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk - # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed - do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. - tl.static_assert(BLOCK_N % BLOCK_M == 0) - curr_m = start_m - step_m = BLOCK_M - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) - - for blk_idx in range(num_steps): - if DEBUG_TRITON: print(f"iter {blk_idx}: curr_m = {curr_m}") # noqa: E701 - offs_m = curr_m + tl.arange(0, BLOCK_M) - # update the mask because offs_m advanced - mask_m = offs_m < seqlen_q - mask_qT = mask_m[None, :] - mask_do = mask_m[:, None] - mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) - if PADDED_HEAD: - mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM - mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM - qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) - # generate dropout mask - if ENABLE_DROPOUT: - # NOTE: dropout is transposed because it is used to mask pT - philox_offs = curr_philox_offset + \ - offs_m[None, :] * stride_dropoutm + \ - offs_n[:, None] * stride_dropoutn - if tl_DROPOUT_USE_PYTORCH: - dropout_offs = offs_m[None, :] * stride_dropoutm + \ - offs_n[:, None] * stride_dropoutn - dropout_mask = tl.load( - curr_dropout_offset + dropout_offs, - mask=mask_nm - ) - else: - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1.0 / (1 - dropout_p) - # Load m before computing qk to reduce pipeline stall. - m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) - if IS_FP8: - qkT = (tl.dot(k, qT) * descale_q * descale_k) - else: - qkT = tl.dot(k, qT) - qkT_scaled = qkT * sm_scale - - if USE_ALIBI: - relative_pos_block = offs_n[:, None] + seqlen_q - seqlen_k - offs_m[None, :] - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - qkT_scaled += alibi_block - - if DEBUG_TRITON_DETAIL: - if start_n == 256: - print(f"qT: {qT.shape}\n", qT) - print(f"k: {k.shape}\n", k) - print(f"qkT scaled: {qkT.shape}\n", qkT_scaled) - # TODO: remove the scaling of m later when we removed re-scaling in fwd - if USE_EXP2: - pT = tl.math.exp2(qkT_scaled * RCP_LN2 - m[None, :] * RCP_LN2) - else: - pT = tl.math.exp(qkT_scaled - m[None, :]) - - # Autoregressive masking. - if MASK: - # offset offs_m with delta_qk since the causal mask starts at - # bottom right of the (seqlen_q, seqlen_k) matrix - causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] - mask = causal_mask & mask_nm - if DEBUG_TRITON_DETAIL: - if start_n == 256: - print(f"causal_mask: {causal_mask.shape}\n", causal_mask) - print(f"qkT after causal: {qkT.shape}\n", tl.where(causal_mask, qkT * sm_scale, 0.0)) - pT = tl.where(mask, pT, 0.0) - do = tl.load(do_ptrs, mask=mask_do, other=0.0) - # Compute dV. - if ENABLE_DROPOUT: - pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) - dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do)* descale_p_dropout * descale_do) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) - else: - if IS_FP8: - scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) - dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) - else: - dv += tl.dot(pT.to(do.type.element_ty), do) - - if DEBUG_TRITON_DETAIL: - if start_n == 256: - print(f"pT: {pT.shape}\n", pT) - # D (= delta) is pre-divided by ds_scale. - Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) - # Compute dP and dS. - if IS_FP8: - dpT = (tl.dot(v, tl.trans(do)) * descale_v * descale_do) - else: - dpT = tl.dot(v, tl.trans(do)) - if ENABLE_DROPOUT: - dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale - delta_i = Di[None, :] - dsT = pT * (dpT - delta_i) - if IS_FP8: - scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) - else: - dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) - # Increment pointers. - curr_m += step_m - qT_ptrs += step_m * stride_qm - do_ptrs += step_m * stride_dom - return dk, dv - - -# grid = (max_seqlen_k // BLOCK_N, batch, nheads_q) -@triton.jit -def _bwd_kernel_dkdv_causal( - Q, K, V, sm_scale, DO, DK, DV, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - HQ, HK, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - Dropout_mask, dropout_p, philox_seed, philox_offset_base, - Alibi_slopes, - Descale_q, Descale_k, Descale_v, Descale_do, - BLOCK_M: tl.constexpr, # 32 - BLOCK_N: tl.constexpr, # 128 - BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - FP8_OUTPUT: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # program ids - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) - # figure out varlen start and end - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - # Compute actual sequence lengths - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - dk = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) - dv = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) - # Figure out causal starting block since we have seqlen_q >=< seqlen_k. - # Unlike forward pass where we tile on M dim and iterate on N dim, so that - # we can skip some M blocks, in backward pass, we tile on the N dim for kv - # and iterate over the M. In this way, we cannot skip N blocks, but only to - # determine the starting M blocks to skip some initial blocks masked by - # causal. - delta_qk = seqlen_q - seqlen_k - if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") - if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") - # q > k: diretcly skip all the way until the start of causal block - start_delta_q_gt_k = delta_qk - # q < k: some blocks will have no Masked block, other needs to re-calc - # starting position - # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the - # masked op - num_blocks_skip = -delta_qk // BLOCK_N - delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk - start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M - if delta_qk >= 0: - start_delta = delta_qk - if DEBUG_TRITON: print(f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}") - else: - start_delta = start_delta_q_lt_k - if DEBUG_TRITON: print(f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}") - # align the delta_qk - start_n = pid * BLOCK_N - - offs_k = tl.arange(0, HEAD_DIM) - offs_n = start_n + tl.arange(0, BLOCK_N) - # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_kv &= mask_k[None, :] - - GROUP_SIZE = HQ // HK - # K/V tensors not changed for the group - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k , mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) - # If MQA / GQA, set the K and V head offsets appropriately. - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - if delta_qk >= 0: - start_m = start_n + start_delta - len_m = BLOCK_N - else: - start_m = max(start_n + delta_qk, 0) - start_m = start_m // BLOCK_M * BLOCK_M - # because we might shift the masked blocks up, we are deeper into - # the masked out region, so we would potentially increase the total - # steps with masked operation to get out of it - residue_m = max(start_n + delta_qk - start_m, 0) - len_m = BLOCK_N + residue_m - if DEBUG_TRITON: print(f"residue_m = {residue_m}") - - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - Q_ptr = Q + adj_q - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - DO_ptr = DO + adj_do - adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - M_ptr = M + adj_delta - Delta_ptr = Delta + adj_delta - - if USE_ALIBI: - alibi_offset = bid * stride_az + hqid * stride_ah - alibi_slope = tl.load(Alibi_slopes + alibi_offset) - else: - alibi_slope = None - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = Dropout_mask + bid * stride_dropoutb + \ - hqid * stride_dropouth - - MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR - # bound the masked operation to q len so it does not have to wast cycles - len_m = min(len_m, seqlen_q) - num_steps = tl.cdiv(len_m, MASK_BLOCK_M) - # when q < k, we may skip the initial masked op - if pid < num_blocks_skip: - num_steps = 0 - - if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) - descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) - descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - # if start_m is negative, the current N-tile has no block on the - # diagonal of causal mask, so everything have no causal mask - if DEBUG_TRITON: print(f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}") - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qk, # strides for q - stride_dom, stride_dok, # strides for o - stride_dropoutm, stride_dropoutn, # strides for dropout - stride_deltam, - MASK_BLOCK_M, BLOCK_N, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - alibi_slope, - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK=True, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - start_m += num_steps * MASK_BLOCK_M - num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) - end_m = start_m + num_steps * BLOCK_M - - if DEBUG_TRITON: print(f"start_m after Masked step: {start_m}; num_steps: {num_steps}") # noqa: E701 - if DEBUG_TRITON: print(f"unMasked: start_n: {start_n}, start_m: {start_m}, end_m: {end_m}, num_steps: {num_steps}") # noqa: E701 - if DEBUG_TRITON: print("unMasked") # noqa: E701 - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qk, # strides for q - stride_dom, stride_dok, # strides for o - stride_dropoutm, stride_dropoutn, # strides for dropout - stride_deltam, - BLOCK_M, BLOCK_N, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - alibi_slope, - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK=False, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - # Write back dV and dK. - adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn - offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk - tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) - dk *= sm_scale - tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) - - -# the main inner-loop logic for computing dQ -@triton.jit -def _bwd_dq_inner( - dq, # output - q, K, V, do, m, Delta, sm_scale, # input - # shared by Q/K/V. - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropoutm, stride_dropoutn, # stride for dropout - stride_deltam, - seqlen_q, seqlen_k, # - BLOCK_M2: tl.constexpr, # - BLOCK_N2: tl.constexpr, # - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, # - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - # Filled in by the wrapper. - start_m, start_n, end_n, num_steps, # - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # if HEAD_DIM is padded - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - delta_qk = seqlen_q - seqlen_k - offs_m = start_m + tl.arange(0, BLOCK_M2) - offs_n = start_n + tl.arange(0, BLOCK_N2) - offs_k = tl.arange(0, HEAD_DIM) - - # mask to make sure not OOB of seqlen_q - mask_m = offs_m < seqlen_q - - kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk - vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk - # D (= delta) is pre-divided by ds_scale. - Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) - # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. - tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) - curr_n = start_n - step_n = BLOCK_N2 - curr_philox_offset = batch_philox_offset - curr_dropout_offset = dropout_offset - RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) - for blk_idx in range(num_steps): - if DEBUG_TRITON: print(f"iter {blk_idx}: curr_n = {curr_n}") # noqa: E701 - offs_n = curr_n + tl.arange(0, BLOCK_N2) - # end_n is needed because the end of causal True might not be perfectly - # aligned with the end of the block - mask_n = offs_n < end_n - if DEBUG_TRITON_DETAIL: print(f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}") # noqa: E701 - if DEBUG_TRITON_DETAIL: print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 - mask_kT = mask_n[None, :] - mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) - if PADDED_HEAD: - mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM - - kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) - vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) - - if ENABLE_DROPOUT: - # NOTE: dropout is transposed because it is used to mask pT - philox_offs = curr_philox_offset + \ - offs_m[:, None] * stride_dropoutm + \ - offs_n[None, :] * stride_dropoutn - if tl_DROPOUT_USE_PYTORCH: - dropout_offs = offs_m[:, None] * stride_dropoutm + \ - offs_n[None, :] * stride_dropoutn - dropout_mask = tl.load( - curr_dropout_offset + dropout_offs, - mask=mask_mn) - else: - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p - dropout_scale = 1 / (1 - dropout_p) - - if IS_FP8: - qk = (tl.dot(q, kT) * descale_q * descale_k) - else: - qk = tl.dot(q, kT) - qk_scaled = qk * sm_scale - - if USE_ALIBI: - relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - qk_scaled += alibi_block - - if DEBUG_TRITON_DETAIL: print(f"qk scaled: {qk.shape}\n", qk_scaled) # noqa: E701 - if USE_EXP2: - p = tl.math.exp2(qk_scaled * RCP_LN2 - m * RCP_LN2) - else: - p = tl.math.exp(qk_scaled - m) - - # Autoregressive masking. - if MASK: - causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] - mask = causal_mask & mask_mn - p = tl.where(mask, p, 0.0) - # Compute dP and dS. - if IS_FP8: - dp = (tl.dot(do, vT) * descale_do * descale_v) - else: - dp = tl.dot(do, vT) - if ENABLE_DROPOUT: - dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale - delta_i = Di[:, None] - ds = p * (dp -delta_i) - # Compute dQ. - # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. - if IS_FP8: - scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) - dq += (tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) * descale_ds * descale_k) - else: - dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) - # Increment pointers. - curr_n += step_n - kT_ptrs += step_n * stride_kn - vT_ptrs += step_n * stride_vn - return dq - - -# grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) -@triton.jit -def _bwd_kernel_dq_causal( - Q, K, V, sm_scale, DO, DQ, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - HQ, HK, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - Dropout_mask, dropout_p, philox_seed, philox_offset_base, - Alibi_slopes, - Descale_q, Descale_k, Descale_v, Descale_do, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - FP8_OUTPUT: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # program ids - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) - # figure out varlen start and end - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - # Compute actual sequence lengths - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - # Figure out causal starting block since we have seqlen_q <=> seqlen_k. - # Unlike forward pass where we tile on M dim and iterate on N dim, so that - # we can skip some M blocks, in backward pass, we tile on the N dim for kv - # and iterate over the M. In this way, we cannot skip N blocks, but only to - # determine the starting M blocks to skip some initial blocks masked by - # causal. - # DQ tiles on M dim and iterate on N dim, so we there could be some tiles we - # can simply skip and we need to adjust starting position. - start_m = pid * BLOCK_M - # seqlen_q > seqlen_k, no need to process these tile for dq - delta_qk = seqlen_q - seqlen_k - if DEBUG_TRITON: print(f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M} = {start_m + BLOCK_M}") # noqa: E701 - if start_m + BLOCK_M < delta_qk: - if DEBUG_TRITON: print(f"start_m + BLOCK_M = {start_m} + {BLOCK_M} = {start_m + BLOCK_M} < delta_qk of {delta_qk}") # noqa: E701 - return - - offs_k = tl.arange(0, HEAD_DIM) - offs_m = start_m + tl.arange(0, BLOCK_M) - # Mask for loading K and V - mask_q = offs_m[:, None] < seqlen_q - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_q &= mask_k[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk - offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn - K += adj_k - V += adj_v - # If MQA / GQA, set the K and V head offsets appropriately. - GROUP_SIZE = HQ // HK - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front - # for every M-tile - end_n = start_m + BLOCK_M - delta_qk - # clamp end_n at [0, seqlen_k] - end_n = max(min(end_n, seqlen_k), 0) - if DEBUG_TRITON: print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - adj_delta = \ - bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - Delta_ptr = Delta + adj_delta - - if USE_ALIBI: - alibi_offset = bid * stride_az + hqid * stride_ah - alibi_slope = tl.load(Alibi_slopes + alibi_offset) - else: - alibi_slope = None - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + \ - bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = \ - Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth - - q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) - m = tl.load(M + adj_delta + offs_m * stride_deltam, - mask=offs_m < seqlen_q) - m = m[:, None] - - MASK_BLOCK_N: tl.constexpr = BLOCK_N // BLK_SLICE_FACTOR - # start can only be 0 at minimum - start_n = max(end_n - BLOCK_M, 0) - num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N) - - if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) - descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) - descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - dq = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) - if DEBUG_TRITON: print(f"pid: {pid}; end_n: {end_n}, start_m: {start_m}") # noqa: E701 - # Compute dQ for masked (diagonal) blocks. - # NOTE: This code scans each row of QK^T backward (from right to left, - # but inside each call to _bwd_dq_inner, from left to right), but that's - # not due to anything important. I just wanted to reuse the loop - # structure for dK & dV above as much as possible. - if DEBUG_TRITON: print(f"Masked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 - dq = _bwd_dq_inner( - dq, - q, K, V, do, m, Delta_ptr, sm_scale, - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropoutm, stride_dropoutn, - stride_deltam, - seqlen_q, seqlen_k, - BLOCK_M, MASK_BLOCK_N, - HEAD_DIM, ACTUAL_HEAD_DIM, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - MASK=True, - ENABLE_DROPOUT=ENABLE_DROPOUT, - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - end_n -= num_steps * MASK_BLOCK_N - num_steps = tl.cdiv(end_n, BLOCK_N) - start_n = max(end_n - num_steps * BLOCK_N, 0) - if DEBUG_TRITON: print(f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 - dq = _bwd_dq_inner( - dq, - q, K, V, do, m, Delta_ptr, sm_scale, - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropoutm, stride_dropoutn, - stride_deltam, - seqlen_q, seqlen_k, - BLOCK_M, BLOCK_N, - HEAD_DIM, ACTUAL_HEAD_DIM, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - MASK=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - # Write back dQ. - adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk - dq *= sm_scale - tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) - - -@triton.jit -def _bwd_kernel_dkdv_noncausal( - Q, K, V, sm_scale, DO, DK, DV, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - HQ, HK, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - Dropout_mask, dropout_p, philox_seed, philox_offset_base, - Alibi_slopes, - Descale_q, Descale_k, Descale_v, Descale_do, - BLOCK_M: tl.constexpr, # 32 - BLOCK_N: tl.constexpr, # 128 - BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - FP8_OUTPUT: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # program ids - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) - # figure out varlen start and end - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - # Compute actual sequence lengths - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - dk = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) - dv = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) - - start_n = pid * BLOCK_N - - offs_k = tl.arange(0, HEAD_DIM) - offs_n = start_n + tl.arange(0, BLOCK_N) - # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_kv &= mask_k[None, :] - - GROUP_SIZE = HQ // HK - # K/V tensors not changed for the group - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk - # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k, mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) - # If MQA / GQA, set the K and V head offsets appropriately. - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - Q_ptr = Q + adj_q - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - DO_ptr = DO + adj_do - adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - M_ptr = M + adj_delta - Delta_ptr = Delta + adj_delta - - if USE_ALIBI: - alibi_offset = bid * stride_az + hqid * stride_ah - alibi_slope = tl.load(Alibi_slopes + alibi_offset) - else: - alibi_slope = None - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = Dropout_mask + bid * stride_dropoutb + \ - hqid * stride_dropouth - - if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) - descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) - descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - # because there is no causal, we always start from the beginning - start_m = 0 - num_steps = tl.cdiv(seqlen_q, BLOCK_M) - dk, dv = _bwd_dkdv_inner( - dk, dv, # output tensors - Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors - stride_qm, stride_qk, # strides for q - stride_dom, stride_dok, # strides for o - stride_dropoutm, stride_dropoutn, # strides for dropout - stride_deltam, - BLOCK_M, BLOCK_N, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim - dropout_p, philox_seed, batch_philox_offset, dropout_offset, # - alibi_slope, - seqlen_q, seqlen_k, # max sequence length for q and k - start_n, start_m, num_steps, # iteration numbers - descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user - MASK=False, # causal masking - ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - # Write back dV and dK. - adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn - offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk - tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) - dk *= sm_scale - tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) - - -@triton.jit -def _bwd_kernel_dq_noncausal( - Q, K, V, sm_scale, DO, DQ, - M, Delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - HQ, HK, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - Dropout_mask, dropout_p, philox_seed, philox_offset_base, - Alibi_slopes, - Descale_q, Descale_k, Descale_v, Descale_do, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - IS_VARLEN: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - FP8_OUTPUT: tl.constexpr, - DEBUG_TRITON: tl.constexpr, - DEBUG_TRITON_DETAIL: tl.constexpr, -): - # program ids - pid = tl.program_id(0) - bid = tl.program_id(1) - hkid = tl.program_id(2) - # figure out varlen start and end - q_start = 0 - k_start = 0 - seqlen_q = max_seqlen_q - seqlen_k = max_seqlen_k - if IS_VARLEN: - # Compute actual sequence lengths - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - k_start = tl.load(cu_seqlens_k + bid) - k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start - - start_m = pid * BLOCK_M - - offs_k = tl.arange(0, HEAD_DIM) - offs_m = start_m + tl.arange(0, BLOCK_M) - # Mask for loading K and V - mask_q = offs_m[:, None] < seqlen_q - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - if PADDED_HEAD: - mask_k = offs_k < ACTUAL_HEAD_DIM - mask_q &= mask_k[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk - offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn - K += adj_k - V += adj_v - # If MQA / GQA, set the K and V head offsets appropriately. - GROUP_SIZE = HQ // HK - for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): - # offset input and output tensor by batch and Q/K heads - adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm - adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom - adj_delta = \ - bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam - Delta_ptr = Delta + adj_delta - - if USE_ALIBI: - alibi_offset = bid * stride_az + hqid * stride_ah - alibi_slope = tl.load(Alibi_slopes + alibi_offset) - else: - alibi_slope = None - - # batch_philox_offset is the ACTUALLY dropout offset - # dropout_offset is for debug purpose and will be removed later - batch_philox_offset = 0 - dropout_offset = 0 - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base + \ - bid * stride_dropoutb + \ - hqid * stride_dropouth - dropout_offset = \ - Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth - - q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) - m = tl.load(M + adj_delta + offs_m * stride_deltam, - mask=offs_m < seqlen_q) - m = m[:, None] - - if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) - descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) - descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) - else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - - # start can only be 0 at minimum - start_n = 0 - end_n = seqlen_k - num_steps = tl.cdiv(seqlen_k, BLOCK_N) - dq = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) - dq = _bwd_dq_inner( - dq, - q, K, V, do, m, Delta_ptr, sm_scale, - stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, - stride_dropoutm, stride_dropoutn, - stride_deltam, - seqlen_q, seqlen_k, - BLOCK_M, BLOCK_N, - HEAD_DIM, ACTUAL_HEAD_DIM, - dropout_p, philox_seed, batch_philox_offset, dropout_offset, - alibi_slope, - start_m, start_n, end_n, num_steps, - descale_q, descale_k, descale_v, descale_do, - MASK=False, - ENABLE_DROPOUT=ENABLE_DROPOUT, - USE_ALIBI=USE_ALIBI, - USE_EXP2=USE_EXP2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - # Write back dQ. - adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk - dq *= sm_scale - tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) - - -def attention_prefill_backward_triton_split_impl( - do: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - softmax_lse: torch.Tensor, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - layout: Literal["bshd", "bhsd", "thd"], - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], - dropout_p: float, - philox_seed: Optional[int], - philox_offset: Optional[int], - use_exp2: bool, - # fp8 - descale_q: Optional[torch.Tensor], - descale_k: Optional[torch.Tensor], - descale_v: Optional[torch.Tensor], - descale_o: Optional[torch.Tensor], - descale_do: Optional[torch.Tensor], - descale_dq: Optional[torch.Tensor], - descale_dk: Optional[torch.Tensor], - descale_dv: Optional[torch.Tensor], - # seqused for FA v3 (currently ignored in this implementation) - seqused_q: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, -): - # debug - DEBUG_TRITON: bool = False - DEBUG_TRITON_DETAIL: bool = False - - # fp8 - IS_FP8 = is_fp8(q) - if IS_FP8: - FP8_MAX = torch.finfo(q.dtype).max - # assert that the main inputs are fp8 - assert is_fp8(do) and is_fp8(q) and is_fp8(k) and is_fp8(v), f"Non fp8 type found: do.dtype={do.dtype}, q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}. All tensors must be fp8." - if is_fp8(o): - FP8_OUTPUT = True - assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o." - assert descale_dq is not None, f"descale_dq is None. In fp8, you need to pass a tensor for descale_dq along with a tensor dq." - assert descale_dk is not None, f"descale_dk is None. In fp8, you need to pass a tensor for descale_dk along with a tensor dk." - assert descale_dv is not None, f"descale_dv is None. In fp8, you need to pass a tensor for descale_dv along with a tensor dv." - else: - FP8_OUTPUT = False - - stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None - stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None - stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None - stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None - stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None - - if DEBUG: - print(f"FP8 path triggered in bwd_prefill_split.py (FP8_OUTPUT={FP8_OUTPUT})") - else: - FP8_MAX = None - FP8_OUTPUT = False - stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = stride_descale_do_z = None - - - # get strides and shape - batch, nheads_q, nheads_k, head_size, max_seqlen_q_final, max_seqlen_k_final = \ - get_shapes_from_layout( - q, k, layout, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k - ) - q_strides, k_strides, v_strides, o_strides = \ - get_strides_from_layout(q, k, v, o, layout) - stride_qb, stride_qh, stride_qm, stride_qk = q_strides - stride_kb, stride_kh, stride_kn, stride_kk = k_strides - stride_vb, stride_vh, stride_vn, stride_vk = v_strides - stride_ob, stride_oh, stride_om, stride_ok = o_strides - dq_strides, dk_strides, dv_strides, do_strides = \ - get_strides_from_layout(dq, dk, dv, do, layout) - stride_dqb, stride_dqh, stride_dqm, stride_dqk = dq_strides - stride_dkb, stride_dkh, stride_dkn, stride_dkk = dk_strides - stride_dvb, stride_dvh, stride_dvn, stride_dvk = dv_strides - stride_dob, stride_doh, stride_dom, stride_dok = do_strides - IS_VARLEN = layout == "thd" - use_dropout = (dropout_p > 0.0) - use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) - - # get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() - padded_d_model = max(padded_d_model, 32) # NOTE: the causal path expects a min of 32. It will cause a compiler assert. - HEAD_DIM = padded_d_model - ACTUAL_HEAD_DIM = head_size - # meta-parameters - # TODO: fix num_stages later - NUM_WARPS, NUM_STAGES = 4, 1 - WAVES_PER_EU = 1 - PRE_BLOCK = 128 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - BLK_SLICE_FACTOR = 2 - - # init delta - delta = torch.zeros_like(softmax_lse) - if IS_VARLEN: - stride_deltab = 0 - stride_deltah, stride_deltam = delta.stride() - else: - stride_deltab, stride_deltah, stride_deltam = delta.stride() - pre_grid = (triton.cdiv(max_seqlen_q_final, PRE_BLOCK), batch, nheads_q) - _bwd_preprocess[pre_grid]( - o, do, - delta, - stride_ob, stride_oh, stride_om, stride_ok, - stride_deltab, stride_deltah, stride_deltam, - stride_descale_do_z, - cu_seqlens_q, max_seqlen_q_final, - descale_do, - BLOCK_M=PRE_BLOCK, - HEAD_DIM=HEAD_DIM, - ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8 - ) - - if DEBUG: - print("delta:", delta, delta.shape) - - # dropout mask tensor for debugging. We dump the dropout mask created in - # the kernel for testing - dropout_mask = None - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ - (0, 0 , 0 , 0) - if use_dropout: - dropout_mask = torch.zeros( - (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), - device=q.device, - dtype=torch.float32 - ) - - if DROPOUT_USE_PYTORCH: - if not IS_VARLEN: - dropout_mask = create_dropout_mask( - dropout_p, - (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), - seed = philox_seed - ) - else: - dropout_mask = create_dropout_mask_varlen( - dropout_p, batch, nheads_q, - cu_seqlens_q, cu_seqlens_k, philox_seed - ) - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ - dropout_mask.stride() - - grid_dkdv = ((max_seqlen_k_final + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) - grid_dq = ((max_seqlen_q_final + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) - if causal: - if DEBUG_TRITON: print(f"_bwd_kernel_dkdv: grid = {grid_dkdv}, block_size = ({BLOCK_M1, BLOCK_N1})", ) # noqa: E701 - _bwd_kernel_dkdv_causal[grid_dkdv]( - q, k, v, sm_scale, do, dk, dv, - softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - nheads_q, nheads_k, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q_final, max_seqlen_k_final, - dropout_mask, dropout_p, philox_seed, philox_offset, - alibi_slopes, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M1, BLOCK_N1, BLK_SLICE_FACTOR, - HEAD_DIM, ACTUAL_HEAD_DIM, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_ALIBI=use_alibi, - USE_EXP2=use_exp2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - FP8_OUTPUT=FP8_OUTPUT, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu = WAVES_PER_EU, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - if DEBUG_TRITON: print(f"\n_bwd_kernel_dq: grid = {grid_dq}, block_size = ({BLOCK_M2, BLOCK_N2})", ) # noqa: E701 - _bwd_kernel_dq_causal[grid_dq]( - q, k, v, sm_scale, do, dq, - softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - nheads_q, nheads_k, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q_final, max_seqlen_k_final, - dropout_mask, dropout_p, philox_seed, philox_offset, - alibi_slopes, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M2, BLOCK_N2, BLK_SLICE_FACTOR, - HEAD_DIM, ACTUAL_HEAD_DIM, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_ALIBI=use_alibi, - USE_EXP2=use_exp2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - FP8_OUTPUT=FP8_OUTPUT, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu = WAVES_PER_EU, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - else: - _bwd_kernel_dkdv_noncausal[grid_dkdv]( - q, k, v, sm_scale, do, dk, dv, - softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dkb, stride_dkh, stride_dkn, stride_dkk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - nheads_q, nheads_k, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q_final, max_seqlen_k_final, - dropout_mask, dropout_p, philox_seed, philox_offset, - alibi_slopes, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M1, BLOCK_N1, BLK_SLICE_FACTOR, - HEAD_DIM, ACTUAL_HEAD_DIM, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_ALIBI=use_alibi, - USE_EXP2=use_exp2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - FP8_OUTPUT=FP8_OUTPUT, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu = WAVES_PER_EU, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - _bwd_kernel_dq_noncausal[grid_dq]( - q, k, v, sm_scale, do, dq, - softmax_lse, delta, - stride_qb, stride_qh, stride_qm, stride_qk, - stride_kb, stride_kh, stride_kn, stride_kk, - stride_vb, stride_vh, stride_vn, stride_vk, - stride_dqb, stride_dqh, stride_dqm, stride_dqk, - stride_deltab, stride_deltah, stride_deltam, - stride_dob, stride_doh, stride_dom, stride_dok, - stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, - stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, - stride_az, stride_ah, - nheads_q, nheads_k, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q_final, max_seqlen_k_final, - dropout_mask, dropout_p, philox_seed, philox_offset, - alibi_slopes, - descale_q, descale_k, descale_v, descale_do, - BLOCK_M2, BLOCK_N2, BLK_SLICE_FACTOR, - HEAD_DIM, ACTUAL_HEAD_DIM, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_ALIBI=use_alibi, - USE_EXP2=use_exp2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - FP8_OUTPUT=FP8_OUTPUT, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu = WAVES_PER_EU, - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - return delta diff --git a/flash_attn/flash_attn_triton_amd/bwd_ref.py b/flash_attn/flash_attn_triton_amd/bwd_ref.py deleted file mode 100644 index cb1637157a0..00000000000 --- a/flash_attn/flash_attn_triton_amd/bwd_ref.py +++ /dev/null @@ -1,545 +0,0 @@ -import torch -import math -from typing import Literal, Optional -from .utils import compute_alibi_tensor_ref - -DEBUG = False -DEBUG_CORE = False - -def attention_backward_core_ref_impl( - do, q, k, v, o, softmax_lse, sm_scale, causal, window_size_left, window_size_right, - dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2 -): - if DEBUG_CORE: - print() - print("attention_backward_core_ref_impl") - print("do:", do, do.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("o:", o, o.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("sm_scale:", sm_scale) - print("causal:", causal) - print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("dropout_p:", dropout_p) - print("philox_seed:", philox_seed) - print("philox_offset:", philox_offset) - print("use_exp2:", use_exp2) - - # cast to float32 - do = do.to(torch.float32) - q = q.to(torch.float32) - k = k.to(torch.float32) - v = v.to(torch.float32) - o = o.to(torch.float32) - softmax_lse = softmax_lse.to(torch.float32) - - # recompute attention_scores. Make sure it matches the forward impl. i.e. It use float32 - attention_scores = torch.matmul(q, k.transpose(-2, -1)) - if DEBUG_CORE: - print("attention_scores:", attention_scores, attention_scores.shape) - - # scale scores - attention_scaled_scores = sm_scale * attention_scores - if DEBUG_CORE: - print("attention_scaled_scores:", attention_scaled_scores, attention_scaled_scores.shape) - - if alibi_slopes is not None: - L_q, L_k = q.shape[1], k.shape[1] - if DEBUG_CORE: - print("alibi_slopes:", alibi_slopes, alibi_slopes.shape) - alibi_bias = compute_alibi_tensor_ref(alibi_slopes, L_q, L_k) - alibi_bias = alibi_bias.reshape(-1, L_q, L_k) - if DEBUG_CORE: - print("alibi_bias:", alibi_bias, alibi_bias.shape) - attention_scaled_scores = attention_scaled_scores + alibi_bias - if DEBUG_CORE: - print("attention_scaled_scores after alibi:", attention_scaled_scores, attention_scaled_scores.shape) - - # Apply masks - L_q, L_k = q.shape[1], k.shape[1] - row_idx = torch.arange(L_q, device=q.device).unsqueeze(1) - col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) - col_offset = L_k - L_q - - mask_applied = False - if causal and (window_size_left, window_size_right) == (-1, -1): - # Pure causal: ensure query doesn't attend to future keys - mask = row_idx >= (col_idx - col_offset) - mask_applied = True - if DEBUG_CORE: - print("causal_mask:", mask) - elif (window_size_left, window_size_right) != (-1, -1): - # Handle the case where window sizes exceed sequence length - if window_size_left >= L_k: - window_size_left = -1 # No left limit - if window_size_right >= L_k: - window_size_right = -1 # No right limit - - if causal: - # Causal + sliding window: ensure we don't attend to future - window_size_right = min(window_size_right, 0) if window_size_right != -1 else 0 - - # Create sliding window mask - # Each query at position i attends to keys in [i + offset - left, i + offset + right] - if window_size_left == -1 and window_size_right == -1: - # No window restriction - mask = torch.ones((L_q, L_k), dtype=torch.bool, device=q.device) - else: - mask = torch.ones((L_q, L_k), dtype=torch.bool, device=q.device) - if window_size_left != -1: - # Each query at position i attends to keys from position (i - left) accounting for offset - mask = mask & (col_idx >= (row_idx + col_offset - window_size_left)) - if window_size_right != -1: - # Each query at position i attends to keys up to position (i + right) accounting for offset - mask = mask & (col_idx <= (row_idx + col_offset + window_size_right)) - - # Apply causal constraint - if causal: - causal_mask = row_idx >= (col_idx - col_offset) - mask = mask & causal_mask - - mask_applied = True - if DEBUG_CORE: - print(f"sliding_window_mask (left={window_size_left}, right={window_size_right}):", mask) - - # Apply the mask if created - if mask_applied: - attention_scaled_scores = attention_scaled_scores.masked_fill( - torch.logical_not(mask.unsqueeze(0)), float('-inf') - ) - if DEBUG_CORE: - print("attention_scaled_scores after masking:", attention_scaled_scores, attention_scaled_scores.shape) - - # compute probabilities using softmax_lse - if use_exp2: - RCP_LN = 1 / math.log(2) - attention_scaled_scores_base2 = attention_scaled_scores * RCP_LN - softmax_lse_base2 = softmax_lse * RCP_LN - softmax_lse_3d = softmax_lse_base2.unsqueeze(-1) - p = torch.exp2(attention_scaled_scores_base2 - softmax_lse_3d) - else: - softmax_lse_3d = softmax_lse.unsqueeze(-1) - p = torch.exp(attention_scaled_scores - softmax_lse_3d) - - # Zero out positions outside the mask - if mask_applied: - p = p.masked_fill(torch.logical_not(mask.unsqueeze(0)), 0.0) - - if DEBUG_CORE: - print("softmax_lse_3d:", softmax_lse_3d, softmax_lse_3d.shape) - print("p:", p, p.shape) - - if dropout_p > 0.0: - rand_vals = torch.rand(p.shape, generator=torch.Generator(device=p.device).manual_seed(philox_seed), device=p.device, dtype=p.dtype) - dropout_mask, dropout_scale = rand_vals > dropout_p, (1.0 / (1 - dropout_p)) - if DEBUG_CORE: - print("dropout_scale:", dropout_scale) - print("dropout_mask:", dropout_mask) - - p_drop = torch.where(dropout_mask, p, torch.zeros_like(p)) - p_drop_scaled = p_drop * dropout_scale - if DEBUG_CORE: - print("dropout_scale:", dropout_scale) - print("p_drop:", p_drop, p_drop.shape) - print("p_drop_scaled:", p_drop_scaled, p_drop_scaled.shape) - - # compute dv - dv = torch.matmul(p_drop_scaled.transpose(-2, -1), do) - if DEBUG_CORE: - print("dv:", dv, dv.shape) - - # compute dp - dp_dropout = torch.matmul(do, v.transpose(-2, -1)) - dp = torch.where(dropout_mask, dp_dropout, torch.zeros_like(dp_dropout)) * dropout_scale - if DEBUG_CORE: - print("dp_dropout:", dp_dropout, dp_dropout.shape) - print("dp:", dp, dp.shape) - else: - # compute dv - dv = torch.matmul(p.transpose(-2, -1), do) - if DEBUG_CORE: - print("dv:", dv, dv.shape) - - # compute dp - dp = torch.matmul(do, v.transpose(-2, -1)) - if DEBUG_CORE: - print("dp:", dp, dp.shape) - - # calculate ds - if True: - delta = torch.sum(o * do, axis=-1).unsqueeze(-1) - else: - delta = torch.sum(p * dp, axis=-1).unsqueeze(-1) - if DEBUG_CORE: - print("delta:", delta, delta.shape) - dscores_scaled = p * (dp - delta) - - # Zero out gradients for positions outside the mask - if mask_applied: - dscores_scaled = dscores_scaled.masked_fill(torch.logical_not(mask.unsqueeze(0)), 0.0) - - ds = dscores_scaled * sm_scale - if DEBUG_CORE: - print("dscores_scaled:", dscores_scaled, dscores_scaled.shape) - print("ds:", ds, ds.shape) - - # compute gradient wrt k & q - dk = torch.matmul(ds.transpose(-2, -1), q) - dq = torch.matmul(ds, k) - if DEBUG_CORE: - print("dk:", dk, dk.shape) - print("dq:", dq, dq.shape) - - # cast back to original dtype - dq = dq.to(torch.float16) - dk = dk.to(torch.float16) - dv = dv.to(torch.float16) - # remove d dim with size 1 - delta = delta.squeeze(-1) - - if DEBUG_CORE: - print("attention_backward_core_ref_impl output") - print("delta:", delta, delta.shape) - print("dv:", dv, dv.shape) - print("dk:", dk, dk.shape) - print("dq:", dq, dq.shape) - - return dq, dk, dv, delta - -def attention_varlen_backward_pytorch_ref_impl( - do, - q, - k, - v, - o, - softmax_lse, - sm_scale, - causal, - window_size_left, - window_size_right, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2, -): - # Ensure the layout is 'thd' - if layout != 'thd': - raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.") - - batch_size = cu_seqlens_q.shape[0] - 1 - nheads_q, head_dim = q.shape[1], q.shape[2] - nheads_k = k.shape[1] - - group_size = nheads_q // nheads_k - if nheads_q % nheads_k != 0: - raise ValueError("nheads_q must be divisible by nheads_k") - - # Pre-allocate outputs - total_L_q = q.shape[0] - total_L_k = k.shape[0] - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - # delta has the same shape as softmax_lse - delta = torch.zeros_like(softmax_lse) - - for i in range(batch_size): - # Get the start and end indices for the current sequence - start_q = cu_seqlens_q[i].item() - end_q = cu_seqlens_q[i + 1].item() - start_k = cu_seqlens_k[i].item() - end_k = cu_seqlens_k[i + 1].item() - - # Extract q_i, k_i, v_i, do_i, o_i, softmax_lse_i - q_i = q[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] - k_i = k[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] - v_i = v[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] - do_i = do[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] - o_i = o[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] - softmax_lse_i = softmax_lse[:, start_q:end_q] # [nheads_q, L_q_i] - - if group_size != 1: - # MQA or GQA case - # Reshape tensors to include group dimension - q_i = q_i.view(q_i.shape[0], nheads_k, group_size, head_dim) - do_i = do_i.view(do_i.shape[0], nheads_k, group_size, head_dim) - o_i = o_i.view(o_i.shape[0], nheads_k, group_size, head_dim) - softmax_lse_i = softmax_lse_i.view(nheads_k, group_size, softmax_lse_i.shape[1]) - # Expand k_i and v_i to match group_size - k_i = k_i.unsqueeze(2).expand(-1, -1, group_size, -1) - v_i = v_i.unsqueeze(2).expand(-1, -1, group_size, -1) - # Flatten the nheads_k and group_size dimensions - q_i = q_i.reshape(q_i.shape[0], nheads_k * group_size, head_dim) - do_i = do_i.reshape(do_i.shape[0], nheads_k * group_size, head_dim) - o_i = o_i.reshape(o_i.shape[0], nheads_k * group_size, head_dim) - softmax_lse_i = softmax_lse_i.reshape(nheads_k * group_size, softmax_lse_i.shape[2]) - k_i = k_i.reshape(k_i.shape[0], nheads_k * group_size, head_dim) - v_i = v_i.reshape(v_i.shape[0], nheads_k * group_size, head_dim) - - # Permute to [nheads_total, L, head_dim] - q_i = q_i.permute(1, 0, 2) - k_i = k_i.permute(1, 0, 2) - v_i = v_i.permute(1, 0, 2) - do_i = do_i.permute(1, 0, 2) - o_i = o_i.permute(1, 0, 2) - - if alibi_slopes is not None: - alibi_slopes_i = alibi_slopes[i] - else: - alibi_slopes_i = None - - # Call the core backward function for this sequence - dq_i, dk_i, dv_i, delta_i = attention_backward_core_ref_impl( - do_i, - q_i, - k_i, - v_i, - o_i, - softmax_lse_i, - sm_scale, - causal, - window_size_left, - window_size_right, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes_i, - use_exp2 - ) - - # Convert back to 'thd' layout - dq_i = dq_i.permute(1, 0, 2) # [L_q_i, nheads_total, head_dim] - dk_i = dk_i.permute(1, 0, 2) # [L_k_i, nheads_total, head_dim] - dv_i = dv_i.permute(1, 0, 2) # [L_k_i, nheads_total, head_dim] - - if group_size != 1: - # Reshape dq_i and delta_i back to original shape - dq_i = dq_i.view(dq_i.shape[0], nheads_k, group_size, head_dim) - L_q_i = delta_i.shape[1] - delta_i = delta_i.view(nheads_k, group_size, L_q_i) - # Sum dk_i and dv_i over group dimension - dk_i = dk_i.view(dk_i.shape[0], nheads_k, group_size, head_dim) - dv_i = dv_i.view(dv_i.shape[0], nheads_k, group_size, head_dim) - dk_i = dk_i.sum(dim=2) - dv_i = dv_i.sum(dim=2) - # Reshape dq_i back to [L_q_i, nheads_q, head_dim] - dq_i = dq_i.reshape(dq_i.shape[0], nheads_q, head_dim) - delta_i = delta_i.reshape(nheads_q, L_q_i) - else: - # No need to reshape - pass - - # Place outputs in pre-allocated tensors - dq[start_q:end_q, :, :] = dq_i - dk[start_k:end_k, :, :] += dk_i # Accumulate gradients for shared keys - dv[start_k:end_k, :, :] += dv_i # Accumulate gradients for shared values - delta[:, start_q:end_q] = delta_i - - return dq, dk, dv, delta - -def attention_vanilla_backward_pytorch_ref_impl( - do, - q, - k, - v, - o, - softmax_lse, - sm_scale, - causal, - window_size_left, - window_size_right, - layout, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2, -): - if layout == "bshd": - if DEBUG: - print() - print("Changing layout to bhsd!") - do = do.transpose(1, 2).contiguous() - q = q.transpose(1, 2).contiguous() - k = k.transpose(1, 2).contiguous() - v = v.transpose(1, 2).contiguous() - o = o.transpose(1, 2).contiguous() - elif layout == "bhsd": - pass - else: - raise ValueError(f"Unknown layout {layout}") - - # Prepare tensors - batch_size, nheads_q, seq_len_q, head_dim = q.shape - batch_size, nheads_k, seq_len_k, head_dim = k.shape - - group_size = nheads_q // nheads_k - if nheads_q % nheads_k != 0: - raise ValueError("nheads_q must be divisible by nheads_k") - - if group_size != 1: - # MQA or GQA case - # Reshape do, q, o to [batch_size, nheads_k, group_size, seq_len_q, head_dim] - do = do.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) - q = q.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) - o = o.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) - # Reshape softmax_lse to [batch_size, nheads_k, group_size, seq_len_q] - softmax_lse = softmax_lse.reshape(batch_size, nheads_k, group_size, seq_len_q) - # Expand k and v to match group_size - k = k.unsqueeze(2).expand(-1, -1, group_size, -1, -1) # [batch_size, nheads_k, group_size, seq_len_k, head_dim] - v = v.unsqueeze(2).expand(-1, -1, group_size, -1, -1) - # Flatten the first three dimensions for computation - do = do.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) - q = q.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) - k = k.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) - v = v.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) - o = o.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) - softmax_lse = softmax_lse.reshape(batch_size * nheads_k * group_size, seq_len_q) - else: - # Standard case - do = do.reshape(batch_size * nheads_q, seq_len_q, head_dim) - q = q.reshape(batch_size * nheads_q, seq_len_q, head_dim) - k = k.reshape(batch_size * nheads_k, seq_len_k, head_dim) - v = v.reshape(batch_size * nheads_k, seq_len_k, head_dim) - o = o.reshape(batch_size * nheads_q, seq_len_q, head_dim) - softmax_lse = softmax_lse.reshape(batch_size * nheads_q, seq_len_q) - - # Call the core backward function - dq, dk, dv, delta = attention_backward_core_ref_impl( - do, - q, - k, - v, - o, - softmax_lse, - sm_scale, - causal, - window_size_left, - window_size_right, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2 - ) - - if group_size != 1: - # Reshape dq back to [batch_size, nheads_k, group_size, seq_len_q, head_dim] - dq = dq.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) - # Reshape delta back to [batch_size, nheads_k, group_size, seq_len_q] - delta = delta.reshape(batch_size, nheads_k, group_size, seq_len_q) - # Sum dk and dv over group_size dimension, since k and v are shared across groups - dk = dk.reshape(batch_size, nheads_k, group_size, seq_len_k, head_dim) - dk = dk.sum(dim=2) # Sum over group_size dimension - dv = dv.reshape(batch_size, nheads_k, group_size, seq_len_k, head_dim) - dv = dv.sum(dim=2) - # Reshape dq to [batch_size, nheads_q, seq_len_q, head_dim] - dq = dq.reshape(batch_size, nheads_k * group_size, seq_len_q, head_dim) - delta = delta.reshape(batch_size, nheads_k * group_size, seq_len_q) - else: - # Standard case - dq = dq.reshape(batch_size, nheads_q, seq_len_q, head_dim) - dk = dk.reshape(batch_size, nheads_k, seq_len_k, head_dim) - dv = dv.reshape(batch_size, nheads_k, seq_len_k, head_dim) - delta = delta.reshape(batch_size, nheads_q, seq_len_q) - - # Go back to original layout - if layout == "bshd": - if DEBUG: - print() - print("Changing back to bshd!") - dq = dq.transpose(1, 2) - dk = dk.transpose(1, 2) - dv = dv.transpose(1, 2) - elif layout == "bhsd": - pass - else: - raise ValueError(f"Unknown layout {layout}") - - return dq, dk, dv, delta - -def attention_backward_pytorch_ref_impl( - do: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - softmax_lse: torch.Tensor, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - window_size_left: int, - window_size_right: int, - layout: Literal["bshd", "bhsd", "thd"], - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], - dropout_p: float, - philox_seed: Optional[int], - philox_offset: Optional[int], - use_exp2: bool -): - if layout == "thd": - dq_ref, dk_ref, dv_ref, delta = attention_varlen_backward_pytorch_ref_impl( - do, - q, - k, - v, - o, - softmax_lse, - sm_scale, - causal, - window_size_left, - window_size_right, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2, - ) - else: - dq_ref, dk_ref, dv_ref, delta = attention_vanilla_backward_pytorch_ref_impl( - do, - q, - k, - v, - o, - softmax_lse, - sm_scale, - causal, - window_size_left, - window_size_right, - layout, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2, - ) - - - # copy into output tensor - dv.copy_(dv_ref.to(dv.dtype)) - dk.copy_(dk_ref.to(dk.dtype)) - dq.copy_(dq_ref.to(dq.dtype)) - - return delta \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index 327967cebf7..bb7edad3494 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -1,37 +1,81 @@ import torch import triton import triton.language as tl -from typing import Literal, Optional, Union -from .utils import AUTOTUNE, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna, is_fp8 +from typing import Literal, Optional +from .utils import ( + DEBUG, + AUTOTUNE, + get_padded_headsize, + get_shape_and_strides_from_layout, + apply_rotary, + is_cdna, + is_fp8, +) -DEBUG = False def get_cdna_autotune_configs(): return [ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), # Fall-back config. - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + ], [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL", + "VARLEN", + "HQ", + "HK", + ] + def get_autotune_configs(): if AUTOTUNE: if is_cdna(): autotune_configs, autotune_keys = get_cdna_autotune_configs() - fwd_auto_tune_configs, fwd_autotune_keys= autotune_configs, autotune_keys - reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys - return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) + fwd_auto_tune_configs, fwd_autotune_keys = autotune_configs, autotune_keys + reduce_auto_tune_configs, reduce_autotune_keys = ( + autotune_configs, + autotune_keys, + ) + return (fwd_auto_tune_configs, fwd_autotune_keys), ( + reduce_auto_tune_configs, + reduce_autotune_keys, + ) else: raise ValueError("Unknown Device Type") else: @@ -52,19 +96,34 @@ def get_autotune_configs(): "HK", ] - fwd_auto_tune_configs, fwd_autotune_keys= autotune_configs, autotune_keys + fwd_auto_tune_configs, fwd_autotune_keys = autotune_configs, autotune_keys reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys - return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) + return (fwd_auto_tune_configs, fwd_autotune_keys), ( + reduce_auto_tune_configs, + reduce_autotune_keys, + ) -(fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) = get_autotune_configs() +(fwd_auto_tune_configs, fwd_autotune_keys), ( + reduce_auto_tune_configs, + reduce_autotune_keys, +) = get_autotune_configs() + @triton.jit def _attn_fwd_inner( - q, kT, v, pos, col_mask, - m_i, l_i, acc, + q, + kT, + v, + pos, + col_mask, + m_i, + l_i, + acc, pid_m, - q_descale, k_descale, v_descale, # FP8 scaling factors + q_descale, + k_descale, + v_descale, # FP8 scaling factors IS_FP8: tl.constexpr, # FP8 flag BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, @@ -81,41 +140,42 @@ def _attn_fwd_inner( # -- compute qk --- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) if IS_FP8: - qk += (tl.dot(q, kT) * q_descale * k_descale) # Apply FP8 scaling + qk += tl.dot(q, kT) * q_descale * k_descale # Apply FP8 scaling else: qk += tl.dot(q, kT) # noqa: F821 if USE_ALIBI: row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) col_idx = pos + tl.arange(0, BLOCK_N) - + # Compute relative positions relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) relative_pos = tl.abs(relative_pos) - + # Compute ALiBi bias alibi_bias = -1 * alibi_slope * relative_pos - qk += (alibi_bias * 1.44269504) + qk += alibi_bias * 1.44269504 # ------------------------------------------------------------------ # masking # ------------------------------------------------------------------ if USE_SLIDING_WINDOW: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # q positions - col_idx = pos + tl.arange(0, BLOCK_N) # k positions - row = row_idx[:, None] # [M,1] - col = col_idx[None, :] # [1,N] + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # q positions + col_idx = pos + tl.arange(0, BLOCK_N) # k positions + row = row_idx[:, None] # [M,1] + col = col_idx[None, :] # [1,N] if IS_CAUSAL: # -------- causal + window -------- - diag = N_CTX_K_FINAL - N_CTX_Q # sk-sq + diag = N_CTX_K_FINAL - N_CTX_Q # sk-sq causal_ok = col <= row + diag - if WINDOW_SIZE_LEFT < 0: # only right window + if WINDOW_SIZE_LEFT < 0: # only right window win_ok = col <= row + diag + WINDOW_SIZE_RIGHT - else: # both sides - win_ok = ((col >= row + diag - WINDOW_SIZE_LEFT) & - (col <= row + diag + WINDOW_SIZE_RIGHT)) - mask = ~(causal_ok & win_ok) # True ⇒ -inf + else: # both sides + win_ok = (col >= row + diag - WINDOW_SIZE_LEFT) & ( + col <= row + diag + WINDOW_SIZE_RIGHT + ) + mask = ~(causal_ok & win_ok) # True ⇒ -inf else: # -------- non-causal window -------- sk, sq = N_CTX_K_FINAL, N_CTX_Q @@ -123,8 +183,8 @@ def _attn_fwd_inner( mask = col > row + (sk - sq) + WINDOW_SIZE_RIGHT else: right = tl.minimum(row + (sk - sq) + WINDOW_SIZE_RIGHT, sk) - left = row + (sk - sq) - WINDOW_SIZE_LEFT - mask = (col > right) | (col < left) + left = row + (sk - sq) - WINDOW_SIZE_LEFT + mask = (col > right) | (col < left) qk = tl.where(mask, float("-inf"), qk) else: if IS_CAUSAL: @@ -144,16 +204,16 @@ def _attn_fwd_inner( # Expect col_mask shape: [BLOCK_N]. True where column is within sequence. qk = tl.where(col_mask[None, :], qk, float("-inf")) - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) # per-row max so far + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) # per-row max so far # rows that are *all* -inf after masking - valid = m_i_new > float("-inf") + valid = m_i_new > float("-inf") # scale previous partial sums safely - alpha = tl.where(valid, tl.math.exp2(m_i - m_i_new), 0.0) + alpha = tl.where(valid, tl.math.exp2(m_i - m_i_new), 0.0) # subtract the row max only on valid rows - qk = tl.where(valid[:, None], qk - m_i_new[:, None], float("-inf")) + qk = tl.where(valid[:, None], qk - m_i_new[:, None], float("-inf")) p = tl.math.exp2(qk) # -- update m_i and l_i -- @@ -167,7 +227,7 @@ def _attn_fwd_inner( acc += tl.dot(p.to(v.dtype), v) * v_descale # Apply FP8 scaling for V else: acc += tl.dot(p.to(v.dtype), v) - + return m_i, l_i, acc @@ -228,7 +288,7 @@ def _fwd_kernel_splitK( stride_vn_d, stride_bt_b, stride_bt_s, - stride_az, + stride_az, stride_ah, stride_q_descale_z, # FP8 descale strides stride_q_descale_h, @@ -286,12 +346,20 @@ def _fwd_kernel_splitK( if IS_FP8: if IS_GQA: # For MQA/GQA, q_descale uses the same indexing as k/v (hk_id) - q_descale = tl.load(Q_Descale + z_id * stride_q_descale_z + hk_id * stride_q_descale_h) + q_descale = tl.load( + Q_Descale + z_id * stride_q_descale_z + hk_id * stride_q_descale_h + ) else: # For MHA, q_descale uses hq_id - q_descale = tl.load(Q_Descale + z_id * stride_q_descale_z + hq_id * stride_q_descale_h) - k_descale = tl.load(K_Descale + z_id * stride_k_descale_z + hk_id * stride_k_descale_h) - v_descale = tl.load(V_Descale + z_id * stride_v_descale_z + hv_id * stride_v_descale_h) + q_descale = tl.load( + Q_Descale + z_id * stride_q_descale_z + hq_id * stride_q_descale_h + ) + k_descale = tl.load( + K_Descale + z_id * stride_k_descale_z + hk_id * stride_k_descale_h + ) + v_descale = tl.load( + V_Descale + z_id * stride_v_descale_z + hv_id * stride_v_descale_h + ) else: q_descale, k_descale, v_descale = 1.0, 1.0, 1.0 @@ -318,21 +386,29 @@ def _fwd_kernel_splitK( # compute ptrs q_offset = Q + hq_id * stride_qh + z_id * stride_qz + g_id * stride_qg q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd - + # Handle block table for paged attention if USE_BLOCK_TABLE: # K and V now point to paged cache # Each batch has its own block table row block_table_ptr = Block_table + z_id * stride_bt_b else: - k_offset = K + hk_id * stride_kh + cache_batch_idx * stride_kz + g_id * stride_kg - v_offset = V + hv_id * stride_vh + cache_batch_idx * stride_vz + g_id * stride_vg + k_offset = ( + K + hk_id * stride_kh + cache_batch_idx * stride_kz + g_id * stride_kg + ) + v_offset = ( + V + hv_id * stride_vh + cache_batch_idx * stride_vz + g_id * stride_vg + ) # compute masks if PADDED_HEAD: q_mask = (offs_m < N_CTX_Q)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[None, :] - kT_mask = (offs_d < ACTUAL_BLOCK_DMODEL)[:, None] & (offs_n < N_CTX_K_FINAL)[None, :] - v_mask = (offs_n < N_CTX_K_FINAL)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[None, :] + kT_mask = (offs_d < ACTUAL_BLOCK_DMODEL)[:, None] & (offs_n < N_CTX_K_FINAL)[ + None, : + ] + v_mask = (offs_n < N_CTX_K_FINAL)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[ + None, : + ] osk_mask = (offs_m < N_CTX_Q)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[None, :] else: q_mask = (offs_m < N_CTX_Q)[:, None] @@ -344,7 +420,7 @@ def _fwd_kernel_splitK( # 2^x instead of exp in the loop because CSE and LICM # don't work as expected with `exp` in the loop qk_scale = sm_scale * 1.44269504 - + # load q: it will stay in SRAM throughout q = tl.load(q_ptrs, mask=q_mask, other=0.0) q = (q * qk_scale).to(q.dtype) @@ -366,22 +442,22 @@ def _fwd_kernel_splitK( # Paged attention: process all KV blocks from cache # Note: Cache should be updated externally before calling this kernel num_kv_blocks = (N_CTX_K_FINAL + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K - + for block_idx in range(num_kv_blocks): # Calculate sequence range for this block block_start = block_idx * BLOCK_SIZE_K block_end = tl.minimum(block_start + BLOCK_SIZE_K, N_CTX_K_FINAL) - + # Check if block overlaps with our split-k range [lo, hi) if block_end > lo and block_start < hi: # Load physical block number physical_block = tl.load(block_table_ptr + block_idx * stride_bt_s) - + # Calculate the range within this block that overlaps with [lo, hi) process_start = tl.maximum(lo - block_start, 0) process_end = tl.minimum(hi - block_start, BLOCK_SIZE_K) process_end = tl.minimum(process_end, block_end - block_start) - + # Instead of forcing a floor alignment to BLOCK_N (which can still skip # part of the intended range if start falls mid-tile for small splits), # start from the raw (possibly unaligned) process_start rounded *down* but @@ -395,51 +471,82 @@ def _fwd_kernel_splitK( process_start = aligned_start else: process_start = aligned_start - + for offset in range(process_start, process_end, BLOCK_N): # Current position (may begin slightly before logical split range; masking fixes it) pos = block_start + offset # Proceed unconditionally; masking below enforces [lo, hi) # Calculate base addresses for K and V in this physical block - k_base = K + physical_block * BLOCK_SIZE_K * stride_kn + hk_id * stride_kh + g_id * stride_kg - v_base = V + physical_block * BLOCK_SIZE_K * stride_vn + hv_id * stride_vh + g_id * stride_vg - + k_base = ( + K + + physical_block * BLOCK_SIZE_K * stride_kn + + hk_id * stride_kh + + g_id * stride_kg + ) + v_base = ( + V + + physical_block * BLOCK_SIZE_K * stride_vn + + hv_id * stride_vh + + g_id * stride_vg + ) + # Offsets within the current block block_offs = offset + offs_n - + # Masks for valid data respecting: # (1) global key length (seq_mask) # (2) block bounds (block_mask) # (3) current split range [lo, hi) - seq_mask = ((pos + offs_n) < N_CTX_K_FINAL) - block_mask = (block_offs < BLOCK_SIZE_K) - end_mask = (block_offs < process_end) + seq_mask = (pos + offs_n) < N_CTX_K_FINAL + block_mask = block_offs < BLOCK_SIZE_K + end_mask = block_offs < process_end split_mask = ((pos + offs_n) >= lo) & ((pos + offs_n) < hi) col_mask = seq_mask & block_mask & end_mask & split_mask - + # Apply masks kT_mask_final = kT_mask & col_mask[None, :] v_mask_final = v_mask & col_mask[:, None] - + # Load K and V - kT_ptrs = k_base + offs_d[:, None] * stride_kd + block_offs[None, :] * stride_kn - v_ptrs = v_base + block_offs[:, None] * stride_vn + offs_d[None, :] * stride_vd - + kT_ptrs = ( + k_base + + offs_d[:, None] * stride_kd + + block_offs[None, :] * stride_kn + ) + v_ptrs = ( + v_base + + block_offs[:, None] * stride_vn + + offs_d[None, :] * stride_vd + ) + kT = tl.load(kT_ptrs, mask=kT_mask_final, other=0.0) v = tl.load(v_ptrs, mask=v_mask_final, other=0.0) - + # Unified inner function handles both paged and contiguous m_i, l_i, acc = _attn_fwd_inner( - q, kT, v, pos, col_mask, - m_i, l_i, acc, + q, + kT, + v, + pos, + col_mask, + m_i, + l_i, + acc, pid_m, - q_descale, k_descale, v_descale, + q_descale, + k_descale, + v_descale, IS_FP8, - BLOCK_M, BLOCK_N, - N_CTX_Q, N_CTX_K_FINAL, - USE_ALIBI, alibi_slope, - USE_SLIDING_WINDOW, IS_CAUSAL, - WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, + BLOCK_M, + BLOCK_N, + N_CTX_Q, + N_CTX_K_FINAL, + USE_ALIBI, + alibi_slope, + USE_SLIDING_WINDOW, + IS_CAUSAL, + WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT, True, ) else: @@ -447,8 +554,16 @@ def _fwd_kernel_splitK( # Note: Cache should be updated externally before calling this kernel # loop over k, v and update accumulator for start_n in range(lo, hi, BLOCK_N): - kT_ptrs = k_offset + offs_d[:, None] * stride_kd + (start_n + offs_n)[None, :] * stride_kn - V_ptrs = v_offset + (start_n + offs_n)[:, None] * stride_vn + offs_d[None, :] * stride_vd + kT_ptrs = ( + k_offset + + offs_d[:, None] * stride_kd + + (start_n + offs_n)[None, :] * stride_kn + ) + V_ptrs = ( + v_offset + + (start_n + offs_n)[:, None] * stride_vn + + offs_d[None, :] * stride_vd + ) # load k kT = tl.load(kT_ptrs, mask=kT_mask, other=0.0) @@ -460,22 +575,37 @@ def _fwd_kernel_splitK( col_valid_mask = offs_n < (hi - start_n) m_i, l_i, acc = _attn_fwd_inner( - q, kT, v, start_n, col_valid_mask, - m_i, l_i, acc, + q, + kT, + v, + start_n, + col_valid_mask, + m_i, + l_i, + acc, pid_m, - q_descale, k_descale, v_descale, + q_descale, + k_descale, + v_descale, IS_FP8, - BLOCK_M, BLOCK_N, - N_CTX_Q, N_CTX_K_FINAL, - USE_ALIBI, alibi_slope, - USE_SLIDING_WINDOW, IS_CAUSAL, - WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, + BLOCK_M, + BLOCK_N, + N_CTX_Q, + N_CTX_K_FINAL, + USE_ALIBI, + alibi_slope, + USE_SLIDING_WINDOW, + IS_CAUSAL, + WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT, BOUNDS_CHECKS_N, ) # write back O osk_offset = Out_splitK + pid_zhg * stride_osk_zhg + pid_splitk * stride_osk_s - osk_ptrs = osk_offset + offs_m[:, None] * stride_osk_m + offs_d[None, :] * stride_osk_d + osk_ptrs = ( + osk_offset + offs_m[:, None] * stride_osk_m + offs_d[None, :] * stride_osk_d + ) tl.store( osk_ptrs, acc, @@ -534,7 +664,6 @@ def _splitK_reduce( offs_splitK = tl.arange(0, splitK_pow2) offs_k = pid_k * K_BLOCK_SIZE + tl.arange(0, K_BLOCK_SIZE) - # compute masks if PADDED_HEAD: o_mask = offs_k < ACTUAL_BLOCK_DMODEL @@ -546,7 +675,11 @@ def _splitK_reduce( metadata_ptr = metadata_offset + offs_splitK * stride_ms + pid_m * stride_mm osk_offset = Out_splitK + pid_zhg * stride_osk_zhg + pid_m * stride_osk_m - osk_ptr = osk_offset + offs_splitK[:, None] * stride_osk_s + offs_k[None, :] * stride_osk_k + osk_ptr = ( + osk_offset + + offs_splitK[:, None] * stride_osk_s + + offs_k[None, :] * stride_osk_k + ) # read max values of each splitK if MASK_SPLITK: @@ -560,7 +693,7 @@ def _splitK_reduce( acc = tl.load(osk_ptr) g_m = tl.max(l_m, axis=0) - + alpha = tl.where(l_m > float("-inf"), tl.math.exp2(l_m - g_m), 0.0) # read sum @@ -569,21 +702,19 @@ def _splitK_reduce( acc = acc * alpha[:, None] g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0) - acc_out = tl.sum(acc, axis=0) / g_sum_safe + acc_out = tl.sum(acc, axis=0) / g_sum_safe # Store output z_id = pid_zhg // (H * G) h_id = (pid_zhg // G) % H g_id = pid_zhg % G - out_offset = Out + z_id * stride_oz + h_id * stride_oh + g_id * stride_og + out_offset = Out + z_id * stride_oz + h_id * stride_oh + g_id * stride_og out_ptr = out_offset + pid_m * stride_om + offs_k tl.store(out_ptr, acc_out, mask=o_mask) # Store lse l_ptrs = LSE + pid_zhg * stride_lse_zhg + pid_m - lse_val = tl.where(g_sum > 0, - (g_m + tl.math.log2(g_sum)) / 1.44269504, - g_m) + lse_val = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m) tl.store(l_ptrs, lse_val) @@ -596,6 +727,7 @@ def cast_uint32_to_half2(scale_shift): shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) return scale, shift + @triton.jit def dequantize( x_, @@ -605,14 +737,18 @@ def dequantize( ): # PACKED_PER_VAL is the number of values packed into # each element x_. For example, for int4 quantization - #and x_ of type int32, PACKED_PER_VAL is 8. + # and x_ of type int32, PACKED_PER_VAL is 8. BLOCK_N: tl.constexpr = x_.shape[0] BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] offsets = tl.arange(0, PACKED_PER_VAL) * 4 - quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + quant_offset = ( + x_[:, None, :] >> offsets[None, :, None] + ) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) - quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)) + quant_offset = tl.view( + quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL) + ) # Trick - instead of converting int4 to float16 we view it as float16 # and then multiply by 32768 * 512 == 2**24 quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) @@ -622,6 +758,7 @@ def dequantize( dequant = quant_offset * scale_512 + shift return dequant + def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: # Scale and shift are such that quantization linearly maps # int4 values range [0..15] to input values range min(k)..max(k) @@ -639,7 +776,9 @@ def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: in_bytes = in_bytes.to(torch.uint8) in_int4 = in_bytes & 0xF in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4) - scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1) + scale_shift = torch.concat( + [scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1 + ) k_quant = torch.concat( [ scale_shift.flatten(start_dim=-2), @@ -656,7 +795,9 @@ def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tens ss_size = num_groups * 4 scale_shift_ui8 = k_ui8[..., 0:ss_size] - scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4) + scale_shift_ui8 = scale_shift_ui8.reshape( + *scale_shift_ui8.shape[:-1], num_groups, 4 + ) scale = scale_shift_ui8[..., 0:2].view(torch.float16) shift = scale_shift_ui8[..., 2:4].view(torch.float16) @@ -668,7 +809,11 @@ def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tens k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) - out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device) + out = torch.empty( + (*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), + dtype=torch.float16, + device=quant_k.device, + ) out[..., ::2] = k1_f16 out[..., 1::2] = k2_f16 out = out.reshape(*k_shape[:-2], -1) @@ -689,26 +834,52 @@ def get_split_k(B: int, G: int, H: int, Mk: int) -> int: split_k = max(split_k, 1) return split_k -def attention_decode_forward_triton_impl( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - k_new: Optional[torch.Tensor], - v_new: Optional[torch.Tensor], - out: torch.Tensor, - sm_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - alibi_slopes: Optional[torch.Tensor], - layout: Literal["bshd"], - cache_seqlens: Optional[torch.Tensor], - cache_batch_idx: Optional[torch.Tensor], - block_table: Optional[torch.Tensor] = None, - q_descale: Optional[torch.Tensor] = None, - k_descale: Optional[torch.Tensor] = None, - v_descale: Optional[torch.Tensor] = None, + +def attention_forward_decode_triton_impl( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_new: Optional[torch.Tensor], + v_new: Optional[torch.Tensor], + out: torch.Tensor, + sm_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + alibi_slopes: Optional[torch.Tensor], + layout: Literal["bshd"], + cache_seqlens: Optional[torch.Tensor], + cache_batch_idx: Optional[torch.Tensor], + block_table: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + # rotary (optional) + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + rotary_interleaved: bool = False, + seqlens_rotary: Optional[torch.Tensor] = None, ): + # apply rotary embedding + if rotary_cos is not None and rotary_sin is not None: + # Prefer explicitly provided rotary sequence start offsets if given; fall back to cache_seqlens. + seqlen_offsets = ( + seqlens_rotary + if seqlens_rotary is not None + else (cache_seqlens if cache_seqlens is not None else 0) + ) + local = (window_size_left != -1) or (window_size_right != -1) + q, k_new = apply_rotary( + q, + k_new, + rotary_cos, + rotary_sin, + causal=causal, + local=local, + interleaved=rotary_interleaved, + seqlen_offsets=seqlen_offsets, + ) + # handle cache updates if k_new is not None and v_new is not None: # Update cache with new KV values @@ -716,7 +887,7 @@ def attention_decode_forward_triton_impl( # Non-paged attention: update cache directly batch_size = k_new.shape[0] seqlen_new = k_new.shape[1] - + if cache_seqlens is not None: # Use cache_seqlens to determine where to insert new KV for b in range(batch_size): @@ -728,14 +899,16 @@ def attention_decode_forward_triton_impl( else: # Append at the end of existing cache seqlen_cache = k_cache.shape[1] - k_cache[:, seqlen_cache - seqlen_new:] = k_new - v_cache[:, seqlen_cache - seqlen_new:] = v_new + k_cache[:, seqlen_cache - seqlen_new :] = k_new + v_cache[:, seqlen_cache - seqlen_new :] = v_new else: # Paged attention: update cache using block table batch_size = k_new.shape[0] seqlen_new = k_new.shape[1] - block_size = k_cache.shape[1] # k_cache shape: [num_blocks, block_size, nheads, head_dim] - + block_size = k_cache.shape[ + 1 + ] # k_cache shape: [num_blocks, block_size, nheads, head_dim] + # Update cache for each batch element for b in range(batch_size): if cache_seqlens is not None: @@ -750,35 +923,37 @@ def attention_decode_forward_triton_impl( else: start_idx = block_idx * block_size break - + # Copy new KV values into the paged cache for i in range(seqlen_new): pos = start_idx + i block_idx = pos // block_size within_block_idx = pos % block_size - + # Get the physical block number from block table if block_idx < block_table.shape[1]: physical_block = int(block_table[b, block_idx].item()) - + # Update k_cache and v_cache at the physical block location k_cache[physical_block, within_block_idx] = k_new[b, i] v_cache[physical_block, within_block_idx] = v_new[b, i] - + # Update cache_seqlens if provided if cache_seqlens is not None: cache_seqlens[b] = start_idx + seqlen_new - + # triton configs BLOCK_M = 16 BLOCK_N = 64 num_stages = 1 num_warps_fwd = 1 num_warps_reduce = 4 - + # kernel_configs is_new_kv = False # Cache has been updated, so no new KV in kernel - use_alibi, (stride_az, stride_ah) = True if alibi_slopes is not None else False, alibi_slopes.stride() if alibi_slopes is not None else (None, None) + use_alibi, (stride_az, stride_ah) = True if alibi_slopes is not None else False, ( + alibi_slopes.stride() if alibi_slopes is not None else (None, None) + ) use_cache_seqlens = cache_seqlens is not None use_sliding_window = window_size_left != -1 or window_size_right != -1 use_block_table = block_table is not None @@ -786,8 +961,13 @@ def attention_decode_forward_triton_impl( NUM_QUANT_GROUPS = 1 # get shapes and strides - (batch_size, seqlen_q, nheads_q, dim_q), (stride_qz, stride_qh, stride_qm, stride_qd) = get_shape_and_strides_from_layout(q, layout) - + (batch_size, seqlen_q, nheads_q, dim_q), ( + stride_qz, + stride_qh, + stride_qm, + stride_qd, + ) = get_shape_and_strides_from_layout(q, layout) + # Handle paged KV cache layout if use_block_table: # For paged attention, k_cache and v_cache have shape [num_blocks, block_size, nheads, head_dim] @@ -801,29 +981,63 @@ def attention_decode_forward_triton_impl( num_blocks_per_seq = block_table.shape[1] seqlen_kc = num_blocks_per_seq * block_size_k seqlen_vc = seqlen_kc - + # Strides for paged layout stride_kc_z = 0 # No batch dimension in paged cache stride_kc_n = k_cache.stride(1) # Sequence stride stride_kc_h = k_cache.stride(2) # Head stride stride_kc_d = k_cache.stride(3) # Dim stride - + stride_vc_z = 0 stride_vc_n = v_cache.stride(1) stride_vc_h = v_cache.stride(2) stride_vc_d = v_cache.stride(3) else: - (_, seqlen_kc, nheads_kc, dim_kc), (stride_kc_z, stride_kc_h, stride_kc_n, stride_kc_d) = get_shape_and_strides_from_layout(k_cache, layout) - (_, seqlen_vc, nheads_vc, dim_vc), (stride_vc_z, stride_vc_h, stride_vc_n, stride_vc_d) = get_shape_and_strides_from_layout(v_cache, layout) + (_, seqlen_kc, nheads_kc, dim_kc), ( + stride_kc_z, + stride_kc_h, + stride_kc_n, + stride_kc_d, + ) = get_shape_and_strides_from_layout(k_cache, layout) + (_, seqlen_vc, nheads_vc, dim_vc), ( + stride_vc_z, + stride_vc_h, + stride_vc_n, + stride_vc_d, + ) = get_shape_and_strides_from_layout(v_cache, layout) block_size_k = 0 # Not used if is_new_kv: - ( _, seqlen_kn, nheads_kn, dim_kn), (stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d) = get_shape_and_strides_from_layout(k_new, layout) - (_, seqlen_vn, nheads_vn, dim_vn), (stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d) = get_shape_and_strides_from_layout(v_new, layout) + (_, seqlen_kn, nheads_kn, dim_kn), ( + stride_kn_z, + stride_kn_h, + stride_kn_n, + stride_kn_d, + ) = get_shape_and_strides_from_layout(k_new, layout) + (_, seqlen_vn, nheads_vn, dim_vn), ( + stride_vn_z, + stride_vn_h, + stride_vn_n, + stride_vn_d, + ) = get_shape_and_strides_from_layout(v_new, layout) else: - ( _, seqlen_kn, nheads_kn, dim_kn), (stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d) = (None, None, None, None), (None, None, None, None) - (_, seqlen_vn, nheads_vn, dim_vn), (stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d) = (None, None, None, None), (None, None, None, None) - (_, seqlen_o, nheads_o, dim_o), (stride_oz, stride_oh, stride_om, stride_od) = get_shape_and_strides_from_layout(out, layout) - assert dim_q == dim_kc == dim_vc, f"Dimensions must match: {dim_q}, {dim_kc}, {dim_vc}" + (_, seqlen_kn, nheads_kn, dim_kn), ( + stride_kn_z, + stride_kn_h, + stride_kn_n, + stride_kn_d, + ) = (None, None, None, None), (None, None, None, None) + (_, seqlen_vn, nheads_vn, dim_vn), ( + stride_vn_z, + stride_vn_h, + stride_vn_n, + stride_vn_d, + ) = (None, None, None, None), (None, None, None, None) + (_, seqlen_o, nheads_o, dim_o), (stride_oz, stride_oh, stride_om, stride_od) = ( + get_shape_and_strides_from_layout(out, layout) + ) + assert ( + dim_q == dim_kc == dim_vc + ), f"Dimensions must match: {dim_q}, {dim_kc}, {dim_vc}" # add extra information needed by the kernels if layout == "bshd": @@ -841,7 +1055,7 @@ def attention_decode_forward_triton_impl( raise ValueError(f"{layout} layout is not supported") # get padded size - dim_padded = get_padded_headsize(dim_kc) + dim_padded = get_padded_headsize(dim_kc) is_padded_head = dim_padded != dim_kc # Handle MQA/GQA case @@ -857,7 +1071,11 @@ def attention_decode_forward_triton_impl( # Use heuristics if use_block_table: # For paged attention, use the actual sequence length from cache_seqlens - max_seqlen = int(cache_seqlens.max().item()) if cache_seqlens is not None else block_size_k + max_seqlen = ( + int(cache_seqlens.max().item()) + if cache_seqlens is not None + else block_size_k + ) split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, max_seqlen) else: split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_kc) @@ -865,37 +1083,63 @@ def attention_decode_forward_triton_impl( # setup grid seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M - grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch_size * n_group_q * heads_per_group_q, split_k) - + grid = lambda META: ( + triton.cdiv(seqlen_q, META["BLOCK_M"]), + batch_size * n_group_q * heads_per_group_q, + split_k, + ) + # create intermediate tensors - out_splitk = torch.empty([batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_kc], dtype=torch.float32, device=q.device) - metadata = torch.empty([batch_size * n_group_q * heads_per_group_q, 2, split_k, seqlen_q_ceil], dtype=torch.float32, device=q.device) - lse = torch.empty((batch_size * n_group_q * heads_per_group_q, seqlen_q), dtype=torch.float32, device=q.device) - + out_splitk = torch.empty( + [batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_kc], + dtype=torch.float32, + device=q.device, + ) + metadata = torch.empty( + [batch_size * n_group_q * heads_per_group_q, 2, split_k, seqlen_q_ceil], + dtype=torch.float32, + device=q.device, + ) + lse = torch.empty( + (batch_size * n_group_q * heads_per_group_q, seqlen_q), + dtype=torch.float32, + device=q.device, + ) + # get intermediate tensor strides stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d = out_splitk.stride() stride_mzhg, stride_m2, stride_ms, stride_mm = metadata.stride() stride_lse_zhg, stride_lse_m = lse.stride() - + # Block table strides if use_block_table: stride_bt_b, stride_bt_s = block_table.stride() else: stride_bt_b, stride_bt_s = 0, 0 - + # FP8 support IS_FP8 = is_fp8(q) if IS_FP8: if (q_descale is None) or (k_descale is None) or (v_descale is None): import warnings - warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) + + warnings.warn( + "FP8 tensors detected but descale factors not provided. Using default scale of 1.0", + UserWarning, + ) # Create default descale tensors if not provided if q_descale is None: - q_descale = torch.ones(batch_size, nheads_q, dtype=torch.float32, device=q.device) + q_descale = torch.ones( + batch_size, nheads_q, dtype=torch.float32, device=q.device + ) if k_descale is None: - k_descale = torch.ones(batch_size, nheads_kc, dtype=torch.float32, device=q.device) + k_descale = torch.ones( + batch_size, nheads_kc, dtype=torch.float32, device=q.device + ) if v_descale is None: - v_descale = torch.ones(batch_size, nheads_vc, dtype=torch.float32, device=q.device) + v_descale = torch.ones( + batch_size, nheads_vc, dtype=torch.float32, device=q.device + ) stride_q_descale_z, stride_q_descale_h = q_descale.stride() stride_k_descale_z, stride_k_descale_h = k_descale.stride() stride_v_descale_z, stride_v_descale_h = v_descale.stride() @@ -911,18 +1155,45 @@ def attention_decode_forward_triton_impl( stride_v_descale_h = 0 if DEBUG: - print("batch_size, seqlen_q, nheads_q, dim_q", (batch_size, seqlen_q, nheads_q, dim_q)) + print( + "batch_size, seqlen_q, nheads_q, dim_q", + (batch_size, seqlen_q, nheads_q, dim_q), + ) print("_, seqlen_kc, nheads_kc, dim_kc", (_, seqlen_kc, nheads_kc, dim_kc)) print("dim_padded:", dim_padded) - print("stride_qz, stride_qm, stride_qg, stride_qh, stride_qd", (stride_qz, stride_qm, stride_qg, stride_qh, stride_qd)) - print("stride_kc_z, stride_kc_n, stride_kc_g, stride_kc_h, stride_kc_d", (stride_kc_z, stride_kc_n, stride_kc_g, stride_kc_h, stride_kc_d)) - print("stride_vc_z, stride_vc_n, stride_vc_g, stride_vc_h, stride_vc_d", (stride_vc_z, stride_vc_n, stride_vc_g, stride_vc_h, stride_vc_d)) + print( + "stride_qz, stride_qm, stride_qg, stride_qh, stride_qd", + (stride_qz, stride_qm, stride_qg, stride_qh, stride_qd), + ) + print( + "stride_kc_z, stride_kc_n, stride_kc_g, stride_kc_h, stride_kc_d", + (stride_kc_z, stride_kc_n, stride_kc_g, stride_kc_h, stride_kc_d), + ) + print( + "stride_vc_z, stride_vc_n, stride_vc_g, stride_vc_h, stride_vc_d", + (stride_vc_z, stride_vc_n, stride_vc_g, stride_vc_h, stride_vc_d), + ) if is_new_kv: - print("stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d", (stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d)) - print("stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d", (stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d)) - print("stride_oz, stride_om, stride_og, stride_oh, stride_od", (stride_oz, stride_om, stride_og, stride_oh, stride_od)) - print("stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d", (stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d)) - print("stride_mzhg, stride_m2, stride_ms, stride_mm", (stride_mzhg, stride_m2, stride_ms, stride_mm)) + print( + "stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d", + (stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d), + ) + print( + "stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d", + (stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d), + ) + print( + "stride_oz, stride_om, stride_og, stride_oh, stride_od", + (stride_oz, stride_om, stride_og, stride_oh, stride_od), + ) + print( + "stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d", + (stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d), + ) + print( + "stride_mzhg, stride_m2, stride_ms, stride_mm", + (stride_mzhg, stride_m2, stride_ms, stride_mm), + ) print("stride_lse_zhg, stride_lse_m", (stride_lse_zhg, stride_lse_m)) _fwd_kernel_splitK[grid]( @@ -1042,7 +1313,6 @@ def attention_decode_forward_triton_impl( k_block_size = dim_padded // k_block_num grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num) - if DEBUG: print("splitK_pow2:", splitK_pow2) print("k_block_num:", k_block_num) @@ -1050,10 +1320,10 @@ def attention_decode_forward_triton_impl( print("grid:", grid) _splitK_reduce[grid]( - out_splitk, - metadata, - out, - lse, + out_splitk, + metadata, + out, + lse, # Split-K output strides stride_osk_zhg=stride_osk_zhg, stride_osk_s=stride_osk_s, @@ -1076,13 +1346,14 @@ def attention_decode_forward_triton_impl( K_BLOCK_SIZE=k_block_size, BLOCK_DMODEL=dim_padded, ACTUAL_BLOCK_DMODEL=dim_kc, - G=n_group_q, + G=n_group_q, H=heads_per_group_q, # TODO: Tune num_warps - split_k=split_k, - splitK_pow2=splitK_pow2, + split_k=split_k, + splitK_pow2=splitK_pow2, MASK_SPLITK=mask_split_k, PADDED_HEAD=is_padded_head, - num_warps=num_warps_reduce) + num_warps=num_warps_reduce, + ) - return lse \ No newline at end of file + return lse.view(batch_size, n_group_q * heads_per_group_q, seqlen_q) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index bb0301c7700..d1036f98c3f 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -2,54 +2,130 @@ import torch import triton import triton.language as tl -from typing import Literal, Optional, Union -from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_arch, is_cdna, is_fp8, is_rdna, create_dropout_mask, DEBUG +from typing import Literal, Optional +from .utils import ( + DEBUG, + AUTOTUNE, + DROPOUT_USE_PYTORCH, + DROPOUT_DUMP, + compute_alibi_block, + compute_fp8_scaling_factors, + get_arch, + is_cdna, + is_fp8, + is_rdna, + create_dropout_mask, + apply_rotary, +) # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) + # ------------------------------- # Autotune # ------------------------------- def get_fwd_prefill_cdna_autotune_configs(): return [ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), # Fall-back config. - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL_QK', 'ACTUAL_BLOCK_DMODEL_V', 'IS_VARLEN', 'HQ', 'HK'] + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + ], [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL_QK", + "ACTUAL_BLOCK_DMODEL_V", + "IS_VARLEN", + "HQ", + "HK", + ] def get_fwd_prefill_rdna_autotune_configs(): return [ - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 4, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 16, "waves_per_eu": 4, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 16, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 4, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=2, + ), # Fall-back config. - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL_QK', 'ACTUAL_BLOCK_DMODEL_V', 'IS_VARLEN', 'HQ', 'HK'] + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=2, + ), + ], [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL_QK", + "ACTUAL_BLOCK_DMODEL_V", + "IS_VARLEN", + "HQ", + "HK", + ] def get_fwd_prefill_autotune_configs(): @@ -64,11 +140,18 @@ def get_fwd_prefill_autotune_configs(): arch = get_arch() if arch == "gfx950": default_config = triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, num_stages=1, num_warps=4, ) - elif arch == "gfx942" and False: # Disabled due shared mem oom in CI when using triton==3.3.0 when using top of tree everything seems fine. + elif ( + arch == "gfx942" and False + ): # Disabled due shared mem oom in CI when using triton==3.3.0 when using top of tree everything seems fine. default_config = triton.Config( {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, @@ -80,10 +163,8 @@ def get_fwd_prefill_autotune_configs(): num_stages=1, num_warps=4, ) - - return [ - default_config - ], [ + + return [default_config], [ "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", @@ -96,27 +177,64 @@ def get_fwd_prefill_autotune_configs(): ] -fwd_prefill_autotune_configs, fwd_prefill_autotune_keys = get_fwd_prefill_autotune_configs() +fwd_prefill_autotune_configs, fwd_prefill_autotune_keys = ( + get_fwd_prefill_autotune_configs() +) + @triton.jit -def _attn_fwd_no_mask(acc, l_i, m_i, - q, k_base_ptrs, v_base_ptrs, bias_base_ptrs, - stride_kn, stride_vk, stride_bn, stride_sn, - start_m, seqlen_k, seqlen_q, - dropout_p, philox_seed, philox_base_ptrs, - sd_mask_base_ptrs, dropout_mask_base_ptrs, - offs_m, offs_n, offs_d_qk, offs_d_v, - block_min, block_max, alibi_slope, - q_descale, k_descale, v_descale, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_P_DESCALE: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_DMODEL_QK: tl.constexpr, BLOCK_DMODEL_V: tl.constexpr, BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD_QK: tl.constexpr, PADDED_HEAD_V: tl.constexpr, - ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, ACTUAL_BLOCK_DMODEL_V: tl.constexpr, - SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, - RETURN_SCORES: tl.constexpr, ACCUMULATOR_TYPE): +def _attn_fwd_no_mask( + acc, + l_i, + m_i, + q, + k_base_ptrs, + v_base_ptrs, + bias_base_ptrs, + stride_kn, + stride_vk, + stride_bn, + stride_sn, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + philox_base_ptrs, + sd_mask_base_ptrs, + dropout_mask_base_ptrs, + offs_m, + offs_n, + offs_d_qk, + offs_d_v, + block_min, + block_max, + alibi_slope, + q_descale, + k_descale, + v_descale, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_P_DESCALE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + PADDED_HEAD_QK: tl.constexpr, + PADDED_HEAD_V: tl.constexpr, + ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, + ACTUAL_BLOCK_DMODEL_V: tl.constexpr, + SM_SCALE: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + RETURN_SCORES: tl.constexpr, + ACCUMULATOR_TYPE, +): if USE_EXP2: RCP_LN2: tl.constexpr = 1.4426950408889634 - + # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # get ptrs @@ -128,37 +246,46 @@ def _attn_fwd_no_mask(acc, l_i, m_i, k_mask, k_mask_other = (offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK), 0.0 else: k_mask, k_mask_other = None, None - + if PADDED_HEAD_V: v_mask, v_mask_other = (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V), 0.0 else: v_mask, v_mask_other = None, None - + # load k and if preload_v then v - k = tl.load(k_ptrs, mask=k_mask, other=k_mask_other) if PADDED_HEAD_QK else tl.load(k_ptrs) + k = ( + tl.load(k_ptrs, mask=k_mask, other=k_mask_other) + if PADDED_HEAD_QK + else tl.load(k_ptrs) + ) if PRE_LOAD_V: - v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD_V else tl.load(v_ptrs) - + v = ( + tl.load(v_ptrs, mask=v_mask, other=v_mask_other) + if PADDED_HEAD_V + else tl.load(v_ptrs) + ) + # setup qk accumlator qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) # -- compute qk ---- - if IS_FP8 : - qk += (tl.dot(q, k) * q_descale * k_descale) + if IS_FP8: + qk += tl.dot(q, k) * q_descale * k_descale else: qk += tl.dot(q, k) - qk_scaled = qk * SM_SCALE + qk_scaled = qk * SM_SCALE if USE_ALIBI: # compute the global position of each token within the sequence q_offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - alibi_block = compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, q_offs_m, - kv_offs_n) + alibi_block = compute_alibi_block( + alibi_slope, seqlen_q, seqlen_k, q_offs_m, kv_offs_n + ) qk_scaled += alibi_block # compute qk mask qk_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) - + # compute bias if bias_base_ptrs is not None: bias_ptrs = bias_base_ptrs + start_n * stride_bn @@ -169,10 +296,10 @@ def _attn_fwd_no_mask(acc, l_i, m_i, m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) # scale and subtract max - q_shifted = tl.where(m_ij[:, None] == float("-inf"), - float("-inf"), - qk_scaled - m_ij[:, None]) - + q_shifted = tl.where( + m_ij[:, None] == float("-inf"), float("-inf"), qk_scaled - m_ij[:, None] + ) + # Compute scaled QK and softmax probabilities if USE_EXP2: p = tl.math.exp2(q_shifted * RCP_LN2) @@ -188,7 +315,9 @@ def _attn_fwd_no_mask(acc, l_i, m_i, if tl_DROPOUT_USE_PYTORCH: dropout_mask = tl.load(dropout_mask_ptrs, mask=qk_mask) else: - rng_output = tl.rand(philox_seed, philox_ptrs) # TODO: use tl.randint for better performance + rng_output = tl.rand( + philox_seed, philox_ptrs + ) # TODO: use tl.randint for better performance dropout_mask = rng_output > dropout_p if tl_DROPOUT_DUMP: tl.store(dropout_mask_ptrs, dropout_mask, mask=qk_mask) @@ -203,21 +332,23 @@ def _attn_fwd_no_mask(acc, l_i, m_i, # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that sd_mask_ptrs = sd_mask_base_ptrs + start_n * stride_sn tl.store(sd_mask_ptrs, p, mask=qk_mask) - + # -- update output accumulator -- # alpha is an adjustment factor for acc and li as we loop and find new maxes # store the diff in maxes to adjust acc and li as we discover new maxes - m_diff = tl.where(m_ij == float("-inf"), - float("-inf"), - m_i - m_ij) + m_diff = tl.where(m_ij == float("-inf"), float("-inf"), m_i - m_ij) if USE_EXP2: alpha = tl.math.exp2(m_diff * RCP_LN2) else: alpha = tl.math.exp(m_diff) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD_V else tl.load(v_ptrs) - + v = ( + tl.load(v_ptrs, mask=v_mask, other=v_mask_other) + if PADDED_HEAD_V + else tl.load(v_ptrs) + ) + # -- update m_i and l_i l_i = l_i * alpha + l_ij m_i = m_ij @@ -225,38 +356,80 @@ def _attn_fwd_no_mask(acc, l_i, m_i, if IS_FP8: if FP8_P_DESCALE: scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) - acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * v_descale) + acc += ( + tl.dot((p * scale_p).to(v.type.element_ty), v) + * descale_p + * v_descale + ) else: acc += tl.dot(p.to(v.type.element_ty), v) * v_descale else: acc += tl.dot(p.to(v.type.element_ty), v) - + return acc, l_i, m_i + @triton.jit -def _attn_fwd_mask(acc, l_i, m_i, - q, k_base_ptrs, v_base_ptrs, bias_base_ptrs, - stride_kn, stride_vk, stride_bn, stride_sn, start_m, - seqlen_k, seqlen_q, - dropout_p, philox_seed, philox_base_ptrs, - sd_mask_base_ptrs, dropout_mask_base_ptrs, - offs_m, offs_n, offs_d_qk, offs_d_v, - block_min, block_max, n_extra_tokens, alibi_slope, - q_descale, k_descale, v_descale, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_P_DESCALE: tl.constexpr, - IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL_QK: tl.constexpr, BLOCK_DMODEL_V: tl.constexpr, BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD_QK: tl.constexpr, PADDED_HEAD_V: tl.constexpr, - ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, ACTUAL_BLOCK_DMODEL_V: tl.constexpr, - SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, - RETURN_SCORES: tl.constexpr, - USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, - ACCUMULATOR_TYPE): +def _attn_fwd_mask( + acc, + l_i, + m_i, + q, + k_base_ptrs, + v_base_ptrs, + bias_base_ptrs, + stride_kn, + stride_vk, + stride_bn, + stride_sn, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + philox_base_ptrs, + sd_mask_base_ptrs, + dropout_mask_base_ptrs, + offs_m, + offs_n, + offs_d_qk, + offs_d_v, + block_min, + block_max, + n_extra_tokens, + alibi_slope, + q_descale, + k_descale, + v_descale, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_P_DESCALE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + PADDED_HEAD_QK: tl.constexpr, + PADDED_HEAD_V: tl.constexpr, + ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, + ACTUAL_BLOCK_DMODEL_V: tl.constexpr, + SM_SCALE: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + RETURN_SCORES: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + ACCUMULATOR_TYPE, +): if USE_EXP2: RCP_LN2: tl.constexpr = 1.4426950408889634 # seqlen diff seqlen_delta_qk = seqlen_k - seqlen_q - + # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # get ptrs @@ -266,21 +439,21 @@ def _attn_fwd_mask(acc, l_i, m_i, # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. kv_offs_n = start_n + tl.arange(0, BLOCK_N) - k_mask = (kv_offs_n[None, :] < seqlen_k) - v_mask = (kv_offs_n[:, None] < seqlen_k) + k_mask = kv_offs_n[None, :] < seqlen_k + v_mask = kv_offs_n[:, None] < seqlen_k if PADDED_HEAD_QK: k_mask = k_mask & (offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK) if PADDED_HEAD_V: v_mask = v_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) - + # load k and if preload_v then v - k = tl.load(k_ptrs, mask=k_mask, other = 0.0) + k = tl.load(k_ptrs, mask=k_mask, other=0.0) if PRE_LOAD_V: v = tl.load(v_ptrs, mask=v_mask, other=0.0) - + # setup qk accumlator qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) - + # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. @@ -296,17 +469,18 @@ def _attn_fwd_mask(acc, l_i, m_i, qk = tl.where(mask, qk, float("-inf")) # -- compute qk ---- - if IS_FP8 : - qk += (tl.dot(q, k) * q_descale * k_descale) + if IS_FP8: + qk += tl.dot(q, k) * q_descale * k_descale else: qk += tl.dot(q, k) - qk_scaled = qk * SM_SCALE + qk_scaled = qk * SM_SCALE if USE_ALIBI: # compute the global position of each token within the sequence q_offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - alibi_block = compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, q_offs_m, - kv_offs_n) + alibi_block = compute_alibi_block( + alibi_slope, seqlen_q, seqlen_k, q_offs_m, kv_offs_n + ) qk_scaled += alibi_block if USE_SLIDING_WINDOW: @@ -315,35 +489,39 @@ def _attn_fwd_mask(acc, l_i, m_i, # For causal sliding window, we need to apply both constraints: # 1. Causal: col_idx <= row_idx + (seqlen_k - seqlen_q) # 2. Sliding window: row_idx - window_left <= col_idx <= row_idx + window_right - + # Get positions row_idx = offs_m # Query positions col_idx = kv_offs_n # Key positions - + # Expand for broadcasting row_idx_expanded = row_idx[:, None] # [BLOCK_M, 1] col_idx_expanded = col_idx[None, :] # [1, BLOCK_N] - + # Apply causal constraint: can only attend to positions before or at the diagonal causal_offset = seqlen_k - seqlen_q causal_mask = col_idx_expanded > (row_idx_expanded + causal_offset) - + # Apply sliding window constraint if WINDOW_SIZE_LEFT < 0: # Only right window constraint - window_mask = col_idx_expanded > (row_idx_expanded + causal_offset + WINDOW_SIZE_RIGHT) + window_mask = col_idx_expanded > ( + row_idx_expanded + causal_offset + WINDOW_SIZE_RIGHT + ) else: # Both left and right window constraints # Adjust window bounds by causal offset left_bound = row_idx_expanded + causal_offset - WINDOW_SIZE_LEFT right_bound = row_idx_expanded + causal_offset + WINDOW_SIZE_RIGHT - + # Can't attend to positions outside the window - window_mask = (col_idx_expanded < left_bound) | (col_idx_expanded > right_bound) - + window_mask = (col_idx_expanded < left_bound) | ( + col_idx_expanded > right_bound + ) + # Final mask is the union of both constraints (True = cannot attend) mask = causal_mask | window_mask - + # Apply mask qk_scaled = tl.where(mask, float("-inf"), qk_scaled) else: @@ -351,25 +529,27 @@ def _attn_fwd_mask(acc, l_i, m_i, # Exactly matching reference construct_local_mask: # row_idx = query positions, col_idx = key positions # sk = seqlen_k, sq = seqlen_q - - # Get positions + + # Get positions row_idx = offs_m # Query positions col_idx = kv_offs_n # Key positions - + # sk and sq from reference (no padding masks in this test) sk = seqlen_k sq = seqlen_q - + # Expand for broadcasting row_idx_expanded = row_idx[:, None] # [BLOCK_M, 1] col_idx_expanded = col_idx[None, :] # [1, BLOCK_N] - + # Reference logic for mask computation if WINDOW_SIZE_LEFT < 0: # Reference: return col_idx > row_idx + sk - sq + window_size[1] - mask = col_idx_expanded > (row_idx_expanded + sk - sq + WINDOW_SIZE_RIGHT) + mask = col_idx_expanded > ( + row_idx_expanded + sk - sq + WINDOW_SIZE_RIGHT + ) else: - # Reference: + # Reference: # sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk # return torch.logical_or( # col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), @@ -378,15 +558,17 @@ def _attn_fwd_mask(acc, l_i, m_i, # Create sk tensor with proper shape for broadcasting # sk represents the key sequence length, which should be compared per column sk_full = tl.full((1, BLOCK_N), sk, dtype=tl.int32) - + # Compute boundaries right_bound_val = row_idx_expanded + sk - sq + WINDOW_SIZE_RIGHT right_bound = tl.minimum(right_bound_val, sk_full) left_bound = row_idx_expanded + sk - sq - WINDOW_SIZE_LEFT - + # Mask where True = cannot attend (matching reference) - mask = (col_idx_expanded > right_bound) | (col_idx_expanded < left_bound) - + mask = (col_idx_expanded > right_bound) | ( + col_idx_expanded < left_bound + ) + # Apply mask (set to -inf where mask is True) qk_scaled = tl.where(mask, float("-inf"), qk_scaled) else: @@ -394,7 +576,7 @@ def _attn_fwd_mask(acc, l_i, m_i, causal_boundary = start_n + offs_n - seqlen_delta_qk causal_mask = offs_m[:, None] >= causal_boundary[None, :] qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) - + # compute qk mask qk_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) @@ -414,12 +596,12 @@ def _attn_fwd_mask(acc, l_i, m_i, if USE_SLIDING_WINDOW: # Check if this block has any valid values (m_ij != -inf) # For rows where everything is -inf, set q_shifted to -inf (not NaN) - q_shifted = tl.where(m_ij[:, None] == float("-inf"), - float("-inf"), - qk_scaled - m_ij[:, None]) + q_shifted = tl.where( + m_ij[:, None] == float("-inf"), float("-inf"), qk_scaled - m_ij[:, None] + ) else: q_shifted = qk_scaled - m_ij[:, None] - + # Compute scaled QK and softmax probabilities if USE_EXP2: p = tl.math.exp2(q_shifted * RCP_LN2) @@ -435,7 +617,9 @@ def _attn_fwd_mask(acc, l_i, m_i, if tl_DROPOUT_USE_PYTORCH: dropout_mask = tl.load(dropout_mask_ptrs, mask=qk_mask) else: - rng_output = tl.rand(philox_seed, philox_ptrs) # TODO: use tl.randint for better performance + rng_output = tl.rand( + philox_seed, philox_ptrs + ) # TODO: use tl.randint for better performance dropout_mask = rng_output > dropout_p if tl_DROPOUT_DUMP: tl.store(dropout_mask_ptrs, dropout_mask, mask=qk_mask) @@ -450,13 +634,11 @@ def _attn_fwd_mask(acc, l_i, m_i, # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that sd_mask_ptrs = sd_mask_base_ptrs + start_n * stride_sn tl.store(sd_mask_ptrs, p, mask=qk_mask) - + # -- update output accumulator -- # alpha is an adjustment factor for acc and li as we loop and find new maxes # store the diff in maxes to adjust acc and li as we discover new maxes - m_diff = tl.where(m_ij == float("-inf"), - float("-inf"), - m_i - m_ij) + m_diff = tl.where(m_ij == float("-inf"), float("-inf"), m_i - m_ij) if USE_EXP2: alpha = tl.math.exp2(m_diff * RCP_LN2) else: @@ -464,7 +646,7 @@ def _attn_fwd_mask(acc, l_i, m_i, acc = acc * alpha[:, None] if not PRE_LOAD_V: v = tl.load(v_ptrs, mask=v_mask, other=0.0) - + # -- update m_i and l_i l_i = l_i * alpha + l_ij m_i = m_ij @@ -472,20 +654,29 @@ def _attn_fwd_mask(acc, l_i, m_i, if IS_FP8: if FP8_P_DESCALE: scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) - acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * v_descale) + acc += ( + tl.dot((p * scale_p).to(v.type.element_ty), v) + * descale_p + * v_descale + ) else: acc += tl.dot(p.to(v.type.element_ty), v) * v_descale else: acc += tl.dot(p.to(v.type.element_ty), v) - + return acc, l_i, m_i @triton.jit -def compute_window_bounds(q_start, q_end, diag, seqlen_k, - WINDOW_SIZE_LEFT: tl.constexpr, - WINDOW_SIZE_RIGHT: tl.constexpr, - IS_CAUSAL: tl.constexpr): +def compute_window_bounds( + q_start, + q_end, + diag, + seqlen_k, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + IS_CAUSAL: tl.constexpr, +): """Calculate the window boundaries for a query block.""" # Left boundary if WINDOW_SIZE_LEFT < 0: @@ -494,8 +685,8 @@ def compute_window_bounds(q_start, q_end, diag, seqlen_k, else: left_min = tl.maximum(0, q_start + diag - WINDOW_SIZE_LEFT) left_max = tl.maximum(0, q_end + diag - WINDOW_SIZE_LEFT) - - # Right boundary + + # Right boundary if IS_CAUSAL: # Causal cap: col ≤ row + diag right_min = tl.minimum(seqlen_k - 1, q_start + diag) @@ -508,41 +699,54 @@ def compute_window_bounds(q_start, q_end, diag, seqlen_k, # Non-causal doesn't have the diagonal constraint right_min = tl.minimum(seqlen_k - 1, q_start + diag + WINDOW_SIZE_RIGHT) right_max = tl.minimum(seqlen_k - 1, q_end + diag + WINDOW_SIZE_RIGHT) - + return left_min, left_max, right_min, right_max + @triton.jit -def classify_window_blocks(left_min, left_max, right_min, right_max, - BLOCK_N: tl.constexpr): +def classify_window_blocks( + left_min, left_max, right_min, right_max, BLOCK_N: tl.constexpr +): """Classify blocks based on window boundaries.""" # First and last blocks that have ANY overlap with window first_block = left_min // BLOCK_N last_block = right_max // BLOCK_N - + # First block that is FULLY visible for all rows in Q block full_left_block = left_max // BLOCK_N + (left_max % BLOCK_N != 0) clipped_left = tl.minimum(full_left_block, last_block + 1) - + # Last block that is FULLY visible for all rows in Q block last_full_block_candidate = right_min // BLOCK_N if (last_full_block_candidate + 1) * BLOCK_N - 1 > right_min: last_full_block_candidate -= 1 full_right_block = tl.maximum(last_full_block_candidate, clipped_left - 1) - + # Calculate counts n_front_skip_blocks = first_block n_front_masked_blocks = tl.maximum(0, clipped_left - first_block) n_full_blocks = tl.maximum(0, full_right_block - clipped_left + 1) n_back_masked_blocks = tl.maximum(0, last_block - full_right_block) - - return (n_front_skip_blocks, n_front_masked_blocks, - n_full_blocks, n_back_masked_blocks, - clipped_left) # Return clipped_left for padded block handling + + return ( + n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + clipped_left, + ) # Return clipped_left for padded block handling + @triton.jit -def handle_padded_last_block(n_extra_tokens, last_block, total_k_blocks, - clipped_left, n_front_masked_blocks, - n_full_blocks, n_back_masked_blocks): +def handle_padded_last_block( + n_extra_tokens, + last_block, + total_k_blocks, + clipped_left, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, +): """Ensure a padded last K-block is never classified as 'full'. We move the padded last block (if visible) into the back-masked bucket. @@ -570,6 +774,7 @@ def handle_padded_last_block(n_extra_tokens, last_block, total_k_blocks, return n_front_masked_blocks, n_full_blocks, n_back_masked_blocks + @triton.jit def compute_padding_info(seqlen_k, BLOCK_N: tl.constexpr): """Calculate padding information for the last K block.""" @@ -589,14 +794,22 @@ def compute_padding_info(seqlen_k, BLOCK_N: tl.constexpr): n_extra_tokens = 0 return n_extra_tokens + @triton.jit -def compute_block_masking(seqlen_k, seqlen_q, start_m, - IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, - WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): +def compute_block_masking( + seqlen_k, + seqlen_q, + start_m, + IS_CAUSAL: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): """ Classify K blocks for attention computation with sliding window support. - + Returns: - n_front_skip_blocks: Blocks completely before the window - n_front_masked_blocks: Blocks partially overlapping window front @@ -607,39 +820,57 @@ def compute_block_masking(seqlen_k, seqlen_q, start_m, # common q_start = start_m * BLOCK_M - q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) - diag = seqlen_k - seqlen_q + q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) + diag = seqlen_k - seqlen_q total_k_blocks = tl.cdiv(seqlen_k, BLOCK_N) n_extra_tokens = compute_padding_info(seqlen_k, BLOCK_N) - + if USE_SLIDING_WINDOW: # get window bounds left_min, left_max, right_min, right_max = compute_window_bounds( - q_start, q_end, diag, seqlen_k, - WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, IS_CAUSAL + q_start, + q_end, + diag, + seqlen_k, + WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT, + IS_CAUSAL, ) # window vanishes → early exit if right_max < left_min: return 0, 0, 0, 0, n_extra_tokens - + # classify blocks - (n_front_skip_blocks, n_front_masked_blocks, - n_full_blocks, n_back_masked_blocks, - clipped_left) = classify_window_blocks( - left_min, left_max, right_min, right_max, BLOCK_N - ) - + ( + n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + clipped_left, + ) = classify_window_blocks(left_min, left_max, right_min, right_max, BLOCK_N) + # handle padded last block if needed if n_extra_tokens != 0: last_block = right_max // BLOCK_N - n_front_masked_blocks, n_full_blocks, n_back_masked_blocks = handle_padded_last_block( - n_extra_tokens, last_block, total_k_blocks, - clipped_left, n_front_masked_blocks, - n_full_blocks, n_back_masked_blocks + n_front_masked_blocks, n_full_blocks, n_back_masked_blocks = ( + handle_padded_last_block( + n_extra_tokens, + last_block, + total_k_blocks, + clipped_left, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + ) ) - return (n_front_skip_blocks, n_front_masked_blocks, - n_full_blocks, n_back_masked_blocks, n_extra_tokens) + return ( + n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + n_extra_tokens, + ) else: if IS_CAUSAL: # ========== CAUSAL MODE: Classify K Blocks ========== @@ -660,19 +891,19 @@ def compute_block_masking(seqlen_k, seqlen_q, start_m, # 1. figure out, in tokens, the right-most K position # this Q-block may attend to # ------------------------------------------------------------ - k_max_token = q_end + diag # last visible K index + k_max_token = q_end + diag # last visible K index # this Q-block is entirely above the diagonal ⇒ nothing to do if k_max_token < 0: return 0, 0, 0, 0, n_extra_tokens - k_max_token = tl.minimum(k_max_token, seqlen_k - 1) + k_max_token = tl.minimum(k_max_token, seqlen_k - 1) # ------------------------------------------------------------ # 2. translate token indices into K-block indices # ------------------------------------------------------------ last_visible_k_block = k_max_token // BLOCK_N - n_visible_k_blocks = tl.minimum(last_visible_k_block + 1, total_k_blocks) + n_visible_k_blocks = tl.minimum(last_visible_k_block + 1, total_k_blocks) # ------------------------------------------------------------ # 3. classify those visible blocks @@ -685,14 +916,14 @@ def compute_block_masking(seqlen_k, seqlen_q, start_m, # middle of a K-block or the last K-block is padded # ------------------------------------------------------------ padded_last_k = n_extra_tokens != 0 - is_modulo_mn = (not padded_last_k) & (seqlen_q % BLOCK_M == 0) + is_modulo_mn = (not padded_last_k) & (seqlen_q % BLOCK_M == 0) n_back_masked_blocks = BLOCK_M // BLOCK_N + tl.where(is_modulo_mn, 0, 1) n_back_masked_blocks = tl.minimum(n_back_masked_blocks, n_visible_k_blocks) - n_front_skip_blocks = 0 # causal never skips the left side - n_front_masked_blocks = 0 # ditto - n_full_blocks = n_visible_k_blocks - n_back_masked_blocks + n_front_skip_blocks = 0 # causal never skips the left side + n_front_masked_blocks = 0 # ditto + n_full_blocks = n_visible_k_blocks - n_back_masked_blocks else: # ========== NON-CAUSAL MODE ========== # Without causal mask, all positions can attend to all positions @@ -707,17 +938,24 @@ def compute_block_masking(seqlen_k, seqlen_q, start_m, # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] - - n_front_skip_blocks = 0 # never skips the left side - n_front_masked_blocks = 0 # ditto + + n_front_skip_blocks = 0 # never skips the left side + n_front_masked_blocks = 0 # ditto if n_extra_tokens != 0: n_back_masked_blocks = 1 # Last block needs padding mask n_full_blocks = total_k_blocks - 1 else: n_back_masked_blocks = 0 # All blocks are aligned n_full_blocks = total_k_blocks - - return n_front_skip_blocks, n_front_masked_blocks, n_full_blocks, n_back_masked_blocks, n_extra_tokens + + return ( + n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + n_extra_tokens, + ) + @triton.autotune( configs=fwd_prefill_autotune_configs, @@ -725,20 +963,86 @@ def compute_block_masking(seqlen_k, seqlen_q, start_m, use_cuda_graph=True, ) @triton.jit -def attn_fwd(Q, K, V, bias, - Q_Descale, K_Descale, V_Descale, stride_q_descale_z, stride_k_descale_z, stride_v_descale_z, - SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, - stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, - stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, - seqused_q, seqused_k, # Add seqused parameters - dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr, - HK: tl.constexpr, ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, ACTUAL_BLOCK_DMODEL_V: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, - USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, BLOCK_M: tl.constexpr, - BLOCK_DMODEL_QK: tl.constexpr, BLOCK_DMODEL_V: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, NEEDS_SDMASK : tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_P_DESCALE: tl.constexpr, USE_SEQUSED: tl.constexpr): +def attn_fwd( + Q, + K, + V, + bias, + Q_Descale, + K_Descale, + V_Descale, + stride_q_descale_z, + stride_k_descale_z, + stride_v_descale_z, + SM_SCALE: tl.constexpr, + LSE, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + stride_az, + stride_ah, + stride_sz, + stride_sh, + stride_sm, + stride_sn, + stride_lse_z, + stride_lse_h, + stride_lse_m, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Add seqused parameters + dropout_p, + philox_seed, + philox_offset_base, + sd_mask, + dropout_mask, + alibi_slopes, + HQ: tl.constexpr, + HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, + ACTUAL_BLOCK_DMODEL_V: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + USE_BIAS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_SCORES: tl.constexpr, + NEEDS_SDMASK: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_P_DESCALE: tl.constexpr, + USE_SEQUSED: tl.constexpr, +): # set params ACCUMULATOR_TYPE = tl.float32 @@ -753,8 +1057,8 @@ def attn_fwd(Q, K, V, bias, else: off_h_k = off_h_q # Determine if we need to mask the heads - PADDED_HEAD_QK: tl.constexpr = (ACTUAL_BLOCK_DMODEL_QK != BLOCK_DMODEL_QK) - PADDED_HEAD_V: tl.constexpr = (ACTUAL_BLOCK_DMODEL_V != BLOCK_DMODEL_V) + PADDED_HEAD_QK: tl.constexpr = ACTUAL_BLOCK_DMODEL_QK != BLOCK_DMODEL_QK + PADDED_HEAD_V: tl.constexpr = ACTUAL_BLOCK_DMODEL_V != BLOCK_DMODEL_V offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) @@ -765,24 +1069,36 @@ def attn_fwd(Q, K, V, bias, if IS_VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) - + # If seqused is provided, use it to limit the actual sequence length if USE_SEQUSED: - actual_seqlen_q = tl.load(seqused_q + off_z) if seqused_q is not None else cu_seqlens_q_end - cu_seqlens_q_start - seqlen_q = tl.minimum(actual_seqlen_q, cu_seqlens_q_end - cu_seqlens_q_start) + actual_seqlen_q = ( + tl.load(seqused_q + off_z) + if seqused_q is not None + else cu_seqlens_q_end - cu_seqlens_q_start + ) + seqlen_q = tl.minimum( + actual_seqlen_q, cu_seqlens_q_end - cu_seqlens_q_start + ) else: seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - + # we have a one-size-fits-all grid in id(0). Some seqlens might be too small for all start_m so for those we return early. if start_m * BLOCK_M > seqlen_q: return cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) - + # If seqused is provided, use it to limit the actual sequence length for keys if USE_SEQUSED: - actual_seqlen_k = tl.load(seqused_k + off_z) if seqused_k is not None else cu_seqlens_k_end - cu_seqlens_k_start - seqlen_k = tl.minimum(actual_seqlen_k, cu_seqlens_k_end - cu_seqlens_k_start) + actual_seqlen_k = ( + tl.load(seqused_k + off_z) + if seqused_k is not None + else cu_seqlens_k_end - cu_seqlens_k_start + ) + seqlen_k = tl.minimum( + actual_seqlen_k, cu_seqlens_k_end - cu_seqlens_k_start + ) else: seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start else: @@ -796,19 +1112,35 @@ def attn_fwd(Q, K, V, bias, # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (off_h_k) # For MHA (GROUP_SIZE == 1), q_descale uses off_h_q (same as off_h_k) if GROUP_SIZE != 1: - q_descale = tl.load(Q_Descale + off_z * stride_q_descale_z + off_h_k) # MQA/GQA: broadcast using k/v head index + q_descale = tl.load( + Q_Descale + off_z * stride_q_descale_z + off_h_k + ) # MQA/GQA: broadcast using k/v head index else: - q_descale = tl.load(Q_Descale + off_z * stride_q_descale_z + off_h_q) # MHA: use q head index + q_descale = tl.load( + Q_Descale + off_z * stride_q_descale_z + off_h_q + ) # MHA: use q head index k_descale = tl.load(K_Descale + off_z * stride_k_descale_z + off_h_k) v_descale = tl.load(V_Descale + off_z * stride_v_descale_z + off_h_k) else: q_descale, k_descale, v_descale = 1.0, 1.0, 1.0 - # figure out masking pattern - n_front_skip_blocks, n_front_masked_blocks, n_full_blocks, n_back_masked_blocks, n_extra_tokens = compute_block_masking( - seqlen_k, seqlen_q, start_m, IS_CAUSAL, USE_SLIDING_WINDOW, - WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, BLOCK_M, BLOCK_N + ( + n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + n_extra_tokens, + ) = compute_block_masking( + seqlen_k, + seqlen_q, + start_m, + IS_CAUSAL, + USE_SLIDING_WINDOW, + WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT, + BLOCK_M, + BLOCK_N, ) # ============================================================ @@ -820,18 +1152,33 @@ def attn_fwd(Q, K, V, bias, No K blocks visible - write zeros and exit. """ # Write zeros to output - o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_offset = ( + Out + + off_z * stride_oz + + off_h_q * stride_oh + + cu_seqlens_q_start * stride_om + ) o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d_v[None, :] * stride_on - o_mask = (offs_m[:, None] < seqlen_q) + o_mask = offs_m[:, None] < seqlen_q if PADDED_HEAD_V: o_mask = o_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) - tl.store(o_ptrs, tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=Out.type.element_ty), mask=o_mask) - + tl.store( + o_ptrs, + tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=Out.type.element_ty), + mask=o_mask, + ) + # Write zeros to LSE - l_ptrs = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m + offs_m * stride_lse_m + l_ptrs = ( + LSE + + off_z * stride_lse_z + + off_h_q * stride_lse_h + + cu_seqlens_q_start * stride_lse_m + + offs_m * stride_lse_m + ) tl.store(l_ptrs, tl.zeros([BLOCK_M], dtype=tl.float32), mask=offs_m < seqlen_q) return - + # ============================================================ # NORMAL PROCESSING (Some K Blocks Visible) # ============================================================ @@ -839,19 +1186,30 @@ def attn_fwd(Q, K, V, bias, This program has visible K blocks to process. We'll use two calls to handle different block types efficiently. """ - + # Initialize for processing # Compute pointers for all the tensors used in this kernel. - q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + q_offset = ( + Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + ) q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qk - k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + k_offset = ( + K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + ) k_ptrs = k_offset + offs_d_qk[:, None] * stride_kk + offs_n[None, :] * stride_kn - v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + v_offset = ( + V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + ) v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d_v[None, :] * stride_vn if USE_BIAS: # Note: this might get large enough to overflow on some configs bias_offset = off_h_q * stride_bh - bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn + bias_ptrs = ( + bias + + bias_offset + + offs_m[:, None] * stride_bm + + offs_n[None, :] * stride_bn + ) else: bias_ptrs = None @@ -862,16 +1220,32 @@ def attn_fwd(Q, K, V, bias, alibi_slope = None if NEEDS_SDMASK: - sd_mask_offset = sd_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm - sd_mask_ptrs = sd_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn + sd_mask_offset = ( + sd_mask + off_z * stride_sz + off_h_q * stride_sh + ) # + cu_seqlens_q_start * stride_sm + sd_mask_ptrs = ( + sd_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn + ) else: sd_mask_ptrs = None if ENABLE_DROPOUT: - dropout_mask_offset = dropout_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm - dropout_mask_ptrs = dropout_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn - batch_philox_offset = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm - philox_ptrs = batch_philox_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn + dropout_mask_offset = ( + dropout_mask + off_z * stride_sz + off_h_q * stride_sh + ) # + cu_seqlens_q_start * stride_sm + dropout_mask_ptrs = ( + dropout_mask_offset + + offs_m[:, None] * stride_sm + + offs_n[None, :] * stride_sn + ) + batch_philox_offset = ( + philox_offset_base + off_z * stride_sz + off_h_q * stride_sh + ) # + cu_seqlens_q_start * stride_sm + philox_ptrs = ( + batch_philox_offset + + offs_m[:, None] * stride_sm + + offs_n[None, :] * stride_sn + ) else: dropout_mask_ptrs = None philox_ptrs = 0 @@ -887,100 +1261,196 @@ def attn_fwd(Q, K, V, bias, q_ptrs_mask = q_ptrs_mask & (offs_d_qk[None, :] < ACTUAL_BLOCK_DMODEL_QK) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) - # ========== Process MASKED K Blocks in the front ========== # NOTE: we use USE_SLIDING_WINDOW as guard because the compiler will crash other wise. front masking is only for sliding window so that is fine. if n_front_masked_blocks > 0 and USE_SLIDING_WINDOW: block_min = n_front_skip_blocks * BLOCK_N block_max = (n_front_skip_blocks + n_front_masked_blocks) * BLOCK_N - + acc, l_i, m_i = _attn_fwd_mask( - acc, l_i, m_i, - q, k_ptrs, v_ptrs, bias_ptrs, - stride_kn, stride_vk, stride_bn, stride_sn, - start_m, seqlen_k, seqlen_q, - dropout_p, philox_seed, philox_ptrs, - sd_mask_ptrs, dropout_mask_ptrs, - offs_m, offs_n, offs_d_qk, offs_d_v, - block_min, # Start of front masked blocks - block_max, # End of front masked blocks - 0, # n_extra_tokens (0 for front blocks, only relevant for last block) - alibi_slope, - q_descale, k_descale, v_descale, IS_FP8, FP8_MAX, FP8_P_DESCALE, + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + stride_sn, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + philox_ptrs, + sd_mask_ptrs, + dropout_mask_ptrs, + offs_m, + offs_n, + offs_d_qk, + offs_d_v, + block_min, # Start of front masked blocks + block_max, # End of front masked blocks + 0, # n_extra_tokens (0 for front blocks, only relevant for last block) + alibi_slope, + q_descale, + k_descale, + v_descale, + IS_FP8, + FP8_MAX, + FP8_P_DESCALE, IS_CAUSAL, - BLOCK_M, BLOCK_DMODEL_QK, BLOCK_DMODEL_V, BLOCK_N, + BLOCK_M, + BLOCK_DMODEL_QK, + BLOCK_DMODEL_V, + BLOCK_N, PRE_LOAD_V, - ENABLE_DROPOUT, PADDED_HEAD_QK, PADDED_HEAD_V, - ACTUAL_BLOCK_DMODEL_QK, ACTUAL_BLOCK_DMODEL_V, SM_SCALE, - USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, - RETURN_SCORES=RETURN_SCORES, - USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, - WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, + ENABLE_DROPOUT, + PADDED_HEAD_QK, + PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK, + ACTUAL_BLOCK_DMODEL_V, + SM_SCALE, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + RETURN_SCORES=RETURN_SCORES, + USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, + WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT, - ACCUMULATOR_TYPE=ACCUMULATOR_TYPE + ACCUMULATOR_TYPE=ACCUMULATOR_TYPE, ) - + # ========== Process FULL K Blocks (Fast Path) ========== if n_full_blocks > 0: block_min = (n_front_skip_blocks + n_front_masked_blocks) * BLOCK_N - block_max = (n_front_skip_blocks + n_front_masked_blocks + n_full_blocks) * BLOCK_N - + block_max = ( + n_front_skip_blocks + n_front_masked_blocks + n_full_blocks + ) * BLOCK_N + acc, l_i, m_i = _attn_fwd_no_mask( - acc, l_i, m_i, - q, k_ptrs, v_ptrs, bias_ptrs, - stride_kn, stride_vk, stride_bn, stride_sn, - start_m, seqlen_k, seqlen_q, - dropout_p, philox_seed, philox_ptrs, - sd_mask_ptrs, dropout_mask_ptrs, - offs_m, offs_n, offs_d_qk, offs_d_v, - block_min, # Start of range: 0 - block_max, # End of range: n_full_blocks * BLOCK_N + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + stride_sn, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + philox_ptrs, + sd_mask_ptrs, + dropout_mask_ptrs, + offs_m, + offs_n, + offs_d_qk, + offs_d_v, + block_min, # Start of range: 0 + block_max, # End of range: n_full_blocks * BLOCK_N alibi_slope, - q_descale, k_descale, v_descale, IS_FP8, FP8_MAX, FP8_P_DESCALE, - BLOCK_M, BLOCK_DMODEL_QK, BLOCK_DMODEL_V, BLOCK_N, + q_descale, + k_descale, + v_descale, + IS_FP8, + FP8_MAX, + FP8_P_DESCALE, + BLOCK_M, + BLOCK_DMODEL_QK, + BLOCK_DMODEL_V, + BLOCK_N, PRE_LOAD_V, - ENABLE_DROPOUT, PADDED_HEAD_QK, PADDED_HEAD_V, - ACTUAL_BLOCK_DMODEL_QK, ACTUAL_BLOCK_DMODEL_V, SM_SCALE, - USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, - RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE + ENABLE_DROPOUT, + PADDED_HEAD_QK, + PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK, + ACTUAL_BLOCK_DMODEL_V, + SM_SCALE, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + RETURN_SCORES=RETURN_SCORES, + ACCUMULATOR_TYPE=ACCUMULATOR_TYPE, ) - + # ========== Process MASKED K Blocks in the back ========== if n_back_masked_blocks > 0: - block_min = (n_front_skip_blocks + n_front_masked_blocks + n_full_blocks) * BLOCK_N - block_max = (n_front_skip_blocks + n_front_masked_blocks + n_full_blocks + n_back_masked_blocks) * BLOCK_N - + block_min = ( + n_front_skip_blocks + n_front_masked_blocks + n_full_blocks + ) * BLOCK_N + block_max = ( + n_front_skip_blocks + + n_front_masked_blocks + + n_full_blocks + + n_back_masked_blocks + ) * BLOCK_N + acc, l_i, m_i = _attn_fwd_mask( - acc, l_i, m_i, - q, k_ptrs, v_ptrs, bias_ptrs, - stride_kn, stride_vk, stride_bn, stride_sn, - start_m, seqlen_k, seqlen_q, - dropout_p, philox_seed, philox_ptrs, - sd_mask_ptrs, dropout_mask_ptrs, - offs_m, offs_n, offs_d_qk, offs_d_v, - block_min, # Start of range: n_full_blocks * BLOCK_N - block_max, # End of range: n_visible_k_blocks * BLOCK_N - n_extra_tokens, # Padding tokens in last block - alibi_slope, - q_descale, k_descale, v_descale, IS_FP8, FP8_MAX, FP8_P_DESCALE, - IS_CAUSAL, # Use actual causal flag - BLOCK_M, BLOCK_DMODEL_QK, BLOCK_DMODEL_V, BLOCK_N, + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + stride_sn, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + philox_ptrs, + sd_mask_ptrs, + dropout_mask_ptrs, + offs_m, + offs_n, + offs_d_qk, + offs_d_v, + block_min, # Start of range: n_full_blocks * BLOCK_N + block_max, # End of range: n_visible_k_blocks * BLOCK_N + n_extra_tokens, # Padding tokens in last block + alibi_slope, + q_descale, + k_descale, + v_descale, + IS_FP8, + FP8_MAX, + FP8_P_DESCALE, + IS_CAUSAL, # Use actual causal flag + BLOCK_M, + BLOCK_DMODEL_QK, + BLOCK_DMODEL_V, + BLOCK_N, PRE_LOAD_V, - ENABLE_DROPOUT, PADDED_HEAD_QK, PADDED_HEAD_V, - ACTUAL_BLOCK_DMODEL_QK, ACTUAL_BLOCK_DMODEL_V, SM_SCALE, - USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, - RETURN_SCORES=RETURN_SCORES, - USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, - WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, + ENABLE_DROPOUT, + PADDED_HEAD_QK, + PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK, + ACTUAL_BLOCK_DMODEL_V, + SM_SCALE, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + RETURN_SCORES=RETURN_SCORES, + USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, + WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT, - ACCUMULATOR_TYPE=ACCUMULATOR_TYPE + ACCUMULATOR_TYPE=ACCUMULATOR_TYPE, ) # ============================================================ # EPILOGUE # ============================================================ # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. - # Instead of directly computing 1/l_i which can be inf, + # Instead of directly computing 1/l_i which can be inf, # we check for the invalid case first if USE_SLIDING_WINDOW: # For rows where m_i is still -inf, no keys were valid @@ -1036,33 +1506,40 @@ def attn_fwd(Q, K, V, bias, # Causal mask (X = can attend, . = cannot): # K0 K1 K2 K3 # Q0 . . . . <- All masked, would give NaN - # Q1 . . . . <- All masked, would give NaN + # Q1 . . . . <- All masked, would give NaN # Q2 X . . . <- First valid row # Q3 X X . . # Q4 X X X . # Q5 X X X X causal_start_idx = seqlen_q - seqlen_k start_m_idx = start_m * BLOCK_M - + # Create mask for rows that need zeroing row_indices = start_m_idx + tl.arange(0, BLOCK_M) causal_mask = row_indices < causal_start_idx - + # Zero out both acc and LSE for these rows if causal_start_idx > start_m_idx: end_m_idx = (start_m + 1) * BLOCK_M if causal_start_idx < end_m_idx: # This block contains the boundary - need to mask acc - out_mask_boundary = tl.full((BLOCK_DMODEL_V, ), causal_start_idx, dtype=tl.int32) + out_mask_boundary = tl.full( + (BLOCK_DMODEL_V,), causal_start_idx, dtype=tl.int32 + ) out_ptrs_mask = row_indices[:, None] >= out_mask_boundary[None, :] z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - + # Zero out LSE for rows above diagonal softmax_lse = tl.where(causal_mask, 0.0, softmax_lse) # write back LSE(Log Sum Exponents), the log of the normalization constant - l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m + l_offset = ( + LSE + + off_z * stride_lse_z + + off_h_q * stride_lse_h + + cu_seqlens_q_start * stride_lse_m + ) l_ptrs = l_offset + offs_m * stride_lse_m # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. @@ -1070,14 +1547,16 @@ def attn_fwd(Q, K, V, bias, end_m_idx = (start_m + 1) * BLOCK_M overflow_size = end_m_idx - seqlen_q if overflow_size > 0: - boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary tl.store(l_ptrs, softmax_lse, mask=l_ptrs_mask) else: tl.store(l_ptrs, softmax_lse) # write back O - o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_offset = ( + Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + ) o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d_v[None, :] * stride_on o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL_V], 1, dtype=tl.int1) if overflow_size > 0: @@ -1087,127 +1566,246 @@ def attn_fwd(Q, K, V, bias, tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) -def attention_prefill_forward_triton_impl( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - window_size_left: int, - window_size_right: int, - bias: Optional[torch.Tensor], - layout: Literal["bshd", "bhsd", "thd"], - # varlen - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlens_q: int, - max_seqlens_k: int, - # dropout - dropout_p: float, - philox_seed: Optional[int], - philox_offset: Optional[int], - # misc - return_softmax: bool, - use_exp2: bool, - # fp8 - q_descale: Optional[torch.Tensor], - k_descale: Optional[torch.Tensor], - v_descale: Optional[torch.Tensor], - # seqused for FA v3 - seqused_q: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, + +def attention_forward_prefill_triton_impl( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + window_size_left: int, + window_size_right: int, + bias: Optional[torch.Tensor], + layout: Literal["bshd", "bhsd", "thd"], + # varlen + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlens_q: int, + max_seqlens_k: int, + # dropout + dropout_p: float, + philox_seed: Optional[int], + philox_offset: Optional[int], + # misc + return_softmax: bool, + use_exp2: bool, + # fp8 + q_descale: Optional[torch.Tensor], + k_descale: Optional[torch.Tensor], + v_descale: Optional[torch.Tensor], + # seqused for FA v3 + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + # rotary (optional) + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + rotary_interleaved: bool = False, + seqlens_rotary: Optional[torch.Tensor] = None, ): # get params, strides and shape IS_VARLEN = layout == "thd" # common assertions - assert 0.0 <= dropout_p <= 1.0, f"dropout_p must be between 0 and 1, got {dropout_p}" - assert q.device == k.device == v.device == o.device, \ - f"All tensors must be on the same device. Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}" + assert ( + 0.0 <= dropout_p <= 1.0 + ), f"dropout_p must be between 0 and 1, got {dropout_p}" + assert ( + q.device == k.device == v.device == o.device + ), f"All tensors must be on the same device. Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}" assert q.dtype == k.dtype == v.dtype, "q, k, v must have the same dtype" current_device = torch.cuda.current_device() - assert q.is_cuda and q.device.index == current_device, f"Device mismatch: Kernel will launch on cuda:{current_device}, but tensors are on {q.device}" - + assert ( + q.is_cuda and q.device.index == current_device + ), f"Device mismatch: Kernel will launch on cuda:{current_device}, but tensors are on {q.device}" + # get shapes and strides if IS_VARLEN: # shape total_seqlen_q, nheads_q, head_size_q = q.shape total_seqlen_k, nheads_k, head_size_k = k.shape total_seqlen_v, nheads_v, head_size_v = v.shape - + # assert shapes - assert cu_seqlens_q is not None, "cu_seqlens_q must be provided for varlen layout" - assert cu_seqlens_k is not None, "cu_seqlens_k must be provided for varlen layout" - assert max_seqlens_q is not None and max_seqlens_q > 0, "max_seqlens_q must be provided and positive for varlen layout" - assert max_seqlens_k is not None and max_seqlens_k > 0, "max_seqlens_k must be provided and positive for varlen layout" - + assert ( + cu_seqlens_q is not None + ), "cu_seqlens_q must be provided for varlen layout" + assert ( + cu_seqlens_k is not None + ), "cu_seqlens_k must be provided for varlen layout" + assert ( + max_seqlens_q is not None and max_seqlens_q > 0 + ), "max_seqlens_q must be provided and positive for varlen layout" + assert ( + max_seqlens_k is not None and max_seqlens_k > 0 + ), "max_seqlens_k must be provided and positive for varlen layout" + # assert head dimensions - assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" - assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" - assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" - + assert ( + head_size_q == head_size_k + ), f"head sizes must match: q={head_size_q}, k={head_size_k}" + assert ( + nheads_k == nheads_v + ), f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert ( + nheads_q % nheads_k == 0 + ), f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + # assert output shapes - assert o.shape == (total_seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_v)}" - + assert o.shape == ( + total_seqlen_q, + nheads_q, + head_size_v, + ), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_v)}" + # assert cu_seqlens - assert cu_seqlens_q.dtype == torch.int32, f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}" - assert cu_seqlens_k.dtype == torch.int32, f"cu_seqlens_k must be int32, got {cu_seqlens_k.dtype}" + assert ( + cu_seqlens_q.dtype == torch.int32 + ), f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}" + assert ( + cu_seqlens_k.dtype == torch.int32 + ), f"cu_seqlens_k must be int32, got {cu_seqlens_k.dtype}" assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0" assert cu_seqlens_k[0] == 0, "cu_seqlens_k must start with 0" - assert cu_seqlens_q[-1] == total_seqlen_q, f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} != total_seqlen_q {total_seqlen_q}" - assert cu_seqlens_k[-1] == total_seqlen_k, f"cu_seqlens_k[-1] {cu_seqlens_k[-1]} != total_seqlen_k {total_seqlen_k}" - + assert ( + cu_seqlens_q[-1] == total_seqlen_q + ), f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} != total_seqlen_q {total_seqlen_q}" + assert ( + cu_seqlens_k[-1] == total_seqlen_k + ), f"cu_seqlens_k[-1] {cu_seqlens_k[-1]} != total_seqlen_k {total_seqlen_k}" + # set vars batch = len(cu_seqlens_q) - 1 head_size_qk = head_size_q - + # softmax_lse shape - softmax_lse = torch.zeros((nheads_q, total_seqlen_q), device=q.device, dtype=torch.float32) - + softmax_lse = torch.zeros( + (nheads_q, total_seqlen_q), device=q.device, dtype=torch.float32 + ) + # strides - stride_qb, stride_qh, stride_qm, stride_qd = 0, q.stride(1), q.stride(0), q.stride(2) - stride_kb, stride_kh, stride_kn, stride_kd = 0, k.stride(1), k.stride(0), k.stride(2) - stride_vb, stride_vh, stride_vn, stride_vd = 0, v.stride(1), v.stride(0), v.stride(2) - stride_ob, stride_oh, stride_om, stride_od = 0, o.stride(1), o.stride(0), o.stride(2) - stride_lse_z, stride_lse_h, stride_lse_m = 0, softmax_lse.stride(0), softmax_lse.stride(1) + stride_qb, stride_qh, stride_qm, stride_qd = ( + 0, + q.stride(1), + q.stride(0), + q.stride(2), + ) + stride_kb, stride_kh, stride_kn, stride_kd = ( + 0, + k.stride(1), + k.stride(0), + k.stride(2), + ) + stride_vb, stride_vh, stride_vn, stride_vd = ( + 0, + v.stride(1), + v.stride(0), + v.stride(2), + ) + stride_ob, stride_oh, stride_om, stride_od = ( + 0, + o.stride(1), + o.stride(0), + o.stride(2), + ) + stride_lse_z, stride_lse_h, stride_lse_m = ( + 0, + softmax_lse.stride(0), + softmax_lse.stride(1), + ) else: # shapes batch_q, seqlen_q, nheads_q, head_size_q = q.shape batch_k, seqlen_k, nheads_k, head_size_k = k.shape batch_v, seqlen_v, nheads_v, head_size_v = v.shape - + # assert batch dimensions - assert batch_q == batch_k == batch_v, f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" - + assert ( + batch_q == batch_k == batch_v + ), f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" + # assert head dimensions - assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" - assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" - assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" - + assert ( + head_size_q == head_size_k + ), f"head sizes must match: q={head_size_q}, k={head_size_k}" + assert ( + nheads_k == nheads_v + ), f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert ( + nheads_q % nheads_k == 0 + ), f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + # assert sequence lengths - assert seqlen_k == seqlen_v, f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" - + assert ( + seqlen_k == seqlen_v + ), f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" + # assert output shapes - assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected {(batch_q, seqlen_q, nheads_q, head_size_v)}" - + assert o.shape == ( + batch_q, + seqlen_q, + nheads_q, + head_size_v, + ), f"o shape {o.shape} != expected {(batch_q, seqlen_q, nheads_q, head_size_v)}" + # set vars batch = batch_q head_size_qk = head_size_q max_seqlens_q = seqlen_q max_seqlens_k = seqlen_k - + # softmax_lse shape - softmax_lse = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32) - + softmax_lse = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) + # strides - stride_qb, stride_qh, stride_qm, stride_qd = q.stride(0), q.stride(2), q.stride(1), q.stride(3) - stride_kb, stride_kh, stride_kn, stride_kd = k.stride(0), k.stride(2), k.stride(1), k.stride(3) - stride_vb, stride_vh, stride_vn, stride_vd = v.stride(0), v.stride(2), v.stride(1), v.stride(3) - stride_ob, stride_oh, stride_om, stride_od = o.stride(0), o.stride(2), o.stride(1), o.stride(3) + stride_qb, stride_qh, stride_qm, stride_qd = ( + q.stride(0), + q.stride(2), + q.stride(1), + q.stride(3), + ) + stride_kb, stride_kh, stride_kn, stride_kd = ( + k.stride(0), + k.stride(2), + k.stride(1), + k.stride(3), + ) + stride_vb, stride_vh, stride_vn, stride_vd = ( + v.stride(0), + v.stride(2), + v.stride(1), + v.stride(3), + ) + stride_ob, stride_oh, stride_om, stride_od = ( + o.stride(0), + o.stride(2), + o.stride(1), + o.stride(3), + ) stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() + # apply rotary embeddings + if rotary_cos is not None and rotary_sin is not None: + if IS_VARLEN: + raise NotImplementedError( + "Rotary embeddings with varlen (thd layout) prefill are not implemented yet." + ) + seqlen_offsets = seqlens_rotary if seqlens_rotary is not None else 0 + local = (window_size_left != -1) or (window_size_right != -1) + q, _ = apply_rotary( + q, + None, + rotary_cos, + rotary_sin, + causal=causal, + local=local, + interleaved=rotary_interleaved, + seqlen_offsets=seqlen_offsets, + ) + # fp8 setup and assertions IS_FP8 = is_fp8(q) if IS_FP8: @@ -1218,39 +1816,56 @@ def attention_prefill_forward_triton_impl( # Check and create default descale tensors if not provided if (q_descale is None) or (k_descale is None) or (v_descale is None): import warnings - warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) + + warnings.warn( + "FP8 tensors detected but descale factors not provided. Using default scale of 1.0", + UserWarning, + ) # Create default descale tensors if not provided if q_descale is None: - q_descale = torch.ones(batch, nheads_q, dtype=torch.float32, device=q.device) + q_descale = torch.ones( + batch, nheads_q, dtype=torch.float32, device=q.device + ) if k_descale is None: - k_descale = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) + k_descale = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) if v_descale is None: - v_descale = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) - + v_descale = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) + # o should be fp32 or fp16/bf16 - assert o.dtype in [torch.float16, torch.bfloat16, torch.float32], \ - f"Output tensor o must be fp16, bf16, or fp32 when using fp8, got {o.dtype}" + assert o.dtype in [ + torch.float16, + torch.bfloat16, + torch.float32, + ], f"Output tensor o must be fp16, bf16, or fp32 when using fp8, got {o.dtype}" stride_q_descale_z = q_descale.stride(0) if q_descale is not None else 0 stride_k_descale_z = k_descale.stride(0) if k_descale is not None else 0 stride_v_descale_z = v_descale.stride(0) if v_descale is not None else 0 - + if DEBUG: print(f"FP8 path triggered in fwd_prefill.py") else: FP8_MAX = None q_descale = k_descale = v_descale = None stride_q_descale_z = stride_k_descale_z = stride_v_descale_z = None - + # check output dtype matches input dtype when not using fp8 - assert o.dtype == q.dtype, f"Output dtype {o.dtype} must match input dtype {q.dtype} when not using fp8" - + assert ( + o.dtype == q.dtype + ), f"Output dtype {o.dtype} must match input dtype {q.dtype} when not using fp8" + # check features - use_sliding_window = window_size_left != -1 or window_size_right!= -1 - use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) + use_sliding_window = window_size_left != -1 or window_size_right != -1 + use_alibi, (stride_az, stride_ah) = ( + (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) + ) # NOTE: a large bias tensor leads to overflow during pointer arithmetic - if (bias is not None): - assert (bias.numel() < 2**31) + if bias is not None: + assert bias.numel() < 2**31 # Get closest power of 2 over or equal to 32 for both QK and V dimensions padded_d_model_qk = 1 << (head_size_qk - 1).bit_length() @@ -1266,48 +1881,122 @@ def attention_prefill_forward_triton_impl( # only. This return holds no useful output aside from debugging. NEEDS_SDMASK = (dropout_p > 0.0) or return_softmax if NEEDS_SDMASK: - sd_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, - dtype=torch.float32) + sd_mask = torch.zeros( + (batch, nheads_q, max_seqlens_q, max_seqlens_k), + device=q.device, + dtype=torch.float32, + ) if DROPOUT_USE_PYTORCH: - dropout_mask = create_dropout_mask(dropout_p, (batch, nheads_q, max_seqlens_q, max_seqlens_k), seed = philox_seed) + dropout_mask = create_dropout_mask( + dropout_p, + (batch, nheads_q, max_seqlens_q, max_seqlens_k), + seed=philox_seed, + ) else: - dropout_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, - dtype=torch.float32) - stride_sz, stride_sh, stride_sm, stride_sn = (sd_mask.stride(0), sd_mask.stride(1), sd_mask.stride(2), sd_mask.stride(3)) + dropout_mask = torch.zeros( + (batch, nheads_q, max_seqlens_q, max_seqlens_k), + device=q.device, + dtype=torch.float32, + ) + stride_sz, stride_sh, stride_sm, stride_sn = ( + sd_mask.stride(0), + sd_mask.stride(1), + sd_mask.stride(2), + sd_mask.stride(3), + ) else: sd_mask = None dropout_mask = None stride_sz, stride_sh, stride_sm, stride_sn = (0, 0, 0, 0) if bias is not None: - stride_bz, stride_bh, stride_bm, stride_bn = (bias.stride(0), bias.stride(1),bias.stride(2), - bias.stride(3)) + stride_bz, stride_bh, stride_bm, stride_bn = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) else: stride_bz, stride_bh, stride_bm, stride_bn = (0, 0, 0, 0) # launch kernel - grid = lambda META: (batch, nheads_q, triton.cdiv(max_seqlens_q, META['BLOCK_M'])) - attn_fwd[grid](q, k, v, bias, - q_descale, k_descale, v_descale, stride_q_descale_z, stride_k_descale_z, stride_v_descale_z, - sm_scale, softmax_lse, o, - stride_qb, stride_qh, stride_qm, stride_qd, - stride_kb, stride_kh, stride_kn, stride_kd, - stride_vb, stride_vh, stride_vn, stride_vd, - stride_ob, stride_oh, stride_om, stride_od, - stride_bz, stride_bh, stride_bm, stride_bn, - stride_az, stride_ah, - stride_sz, stride_sh, stride_sm, stride_sn, - stride_lse_z, stride_lse_h, stride_lse_m, - cu_seqlens_q, cu_seqlens_k, - seqused_q, seqused_k, # Pass seqused tensors - dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, - HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL_QK=head_size_qk, ACTUAL_BLOCK_DMODEL_V=head_size_v, MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, - USE_SLIDING_WINDOW=use_sliding_window, WINDOW_SIZE_LEFT=window_size_left, WINDOW_SIZE_RIGHT=window_size_right, - IS_VARLEN=IS_VARLEN, - BLOCK_DMODEL_QK=padded_d_model_qk, BLOCK_DMODEL_V=padded_d_model_v, USE_BIAS=False if bias is None else True, - USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, NEEDS_SDMASK=NEEDS_SDMASK, - IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_P_DESCALE=False, - USE_SEQUSED=(seqused_q is not None or seqused_k is not None)) # Add flag for seqused - - return softmax_lse, sd_mask if return_softmax else None \ No newline at end of file + grid = lambda META: (batch, nheads_q, triton.cdiv(max_seqlens_q, META["BLOCK_M"])) + attn_fwd[grid]( + q, + k, + v, + bias, + q_descale, + k_descale, + v_descale, + stride_q_descale_z, + stride_k_descale_z, + stride_v_descale_z, + sm_scale, + softmax_lse, + o, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + stride_az, + stride_ah, + stride_sz, + stride_sh, + stride_sm, + stride_sn, + stride_lse_z, + stride_lse_h, + stride_lse_m, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Pass seqused tensors + dropout_p=dropout_p, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + sd_mask=sd_mask, + dropout_mask=dropout_mask, + alibi_slopes=alibi_slopes, + HQ=nheads_q, + HK=nheads_k, + ACTUAL_BLOCK_DMODEL_QK=head_size_qk, + ACTUAL_BLOCK_DMODEL_V=head_size_v, + MAX_SEQLENS_Q=max_seqlens_q, + MAX_SEQLENS_K=max_seqlens_k, + IS_CAUSAL=causal, + USE_SLIDING_WINDOW=use_sliding_window, + WINDOW_SIZE_LEFT=window_size_left, + WINDOW_SIZE_RIGHT=window_size_right, + IS_VARLEN=IS_VARLEN, + BLOCK_DMODEL_QK=padded_d_model_qk, + BLOCK_DMODEL_V=padded_d_model_v, + USE_BIAS=False if bias is None else True, + USE_ALIBI=use_alibi, + ENABLE_DROPOUT=dropout_p > 0.0, + USE_EXP2=use_exp2, + RETURN_SCORES=return_softmax, + NEEDS_SDMASK=NEEDS_SDMASK, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_P_DESCALE=False, + USE_SEQUSED=(seqused_q is not None or seqused_k is not None), + ) # Add flag for seqused + + return softmax_lse, sd_mask if return_softmax else None diff --git a/flash_attn/flash_attn_triton_amd/fwd_ref.py b/flash_attn/flash_attn_triton_amd/fwd_ref.py deleted file mode 100755 index a8ca54a7ec3..00000000000 --- a/flash_attn/flash_attn_triton_amd/fwd_ref.py +++ /dev/null @@ -1,889 +0,0 @@ -import torch -import math -from typing import Literal, Optional, Union -from .utils import compute_alibi_tensor_ref - -DEBUG = False -DEBUG_CORE = False - -def attention_forward_core_ref_impl( - q, k, v, sm_scale, causal, window_size_left, window_size_right, - dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2, - cache_seqlens=None, block_table=None, paged_kv_block_size=None -): - if DEBUG_CORE: - print() - print("attention_forward_core_ref_impl") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("sm_scale:", sm_scale) - print("causal:", causal) - print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("dropout_p:", dropout_p) - print("philox_seed:", philox_seed) - print("philox_offset:", philox_offset) - print("use_exp2:", use_exp2) - print("cache_seqlens:", cache_seqlens) - print("block_table:", block_table) - print("paged_kv_block_size:", paged_kv_block_size) - - # cast to float32 - q = q.to(torch.float32) - - # Check if we're in paged KV mode - is_paged = block_table is not None and paged_kv_block_size is not None - - if False: # Debug paged attention (disabled for production) - print(f"\n=== attention_forward_core_ref_impl DEBUG ===") - print(f"is_paged: {is_paged}") - print(f"block_table: {block_table.shape if block_table is not None else None}") - print(f"paged_kv_block_size: {paged_kv_block_size}") - if is_paged: - print(f"k shape (paged): {k.shape}") - print(f"v shape (paged): {v.shape}") - print(f"cache_seqlens: {cache_seqlens}") - - if is_paged: - # In paged mode, k and v are [num_blocks, block_size, nheads_k, head_dim] - # We'll compute attention on-the-fly without reconstructing - nheads_q = q.shape[0] - L_q = q.shape[1] - head_dim = q.shape[2] - - # Get number of KV heads from the cache - nheads_k = k.shape[2] # k shape: [num_blocks, block_size, nheads_k, head_dim] - - # Handle MQA/GQA - assert nheads_q % nheads_k == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_k ({nheads_k})" - group_size = nheads_q // nheads_k - - # Determine the actual KV sequence length from cache_seqlens - L_k = cache_seqlens if isinstance(cache_seqlens, int) else cache_seqlens.item() - - if False: # Debug disabled - print(f"L_q: {L_q}, L_k: {L_k}, nheads_q: {nheads_q}, nheads_k: {nheads_k}, group_size: {group_size}, head_dim: {head_dim}") - print(f"block_table contents: {block_table if block_table is not None else 'None'}") - - # Initialize attention scores - attention_scores = torch.zeros((nheads_q, L_q, L_k), dtype=torch.float32, device=q.device) - - # Compute attention scores on-the-fly by accessing blocks directly - for kv_pos in range(L_k): - # Calculate which block and position within block - block_idx = kv_pos // paged_kv_block_size - within_block_idx = kv_pos % paged_kv_block_size - - # Get the physical block number from block_table - # block_table is [1, num_blocks] for single batch in core function - if block_table.dim() == 2: - physical_block = block_table[0, block_idx].item() - else: - physical_block = block_table[block_idx].item() - - # Debug output disabled - # if kv_pos == 0: - # print(f"First KV access: block_idx={block_idx}, within_block={within_block_idx}, physical_block={physical_block}") - # print(f"k_vec shape will be: {k[physical_block, within_block_idx, :, :].shape}") - - # Access k values directly from paged cache - # k shape: [num_blocks, block_size, nheads_k, head_dim] - k_vec = k[physical_block, within_block_idx, :, :].to(torch.float32) # [nheads_k, head_dim] - - # For GQA/MQA, we need to repeat k_vec for each group - if group_size > 1: - # Expand k_vec to match query heads - # k_vec: [nheads_k, head_dim] -> [nheads_q, head_dim] - k_vec = k_vec.repeat_interleave(group_size, dim=0) - - # Compute dot product with all query positions - # q is [nheads_q, L_q, head_dim], k_vec is [nheads_q, head_dim] - # Result should be [nheads_q, L_q] for this kv_pos - attention_scores[:, :, kv_pos] = torch.sum(q * k_vec.unsqueeze(1), dim=-1) - - # Keep k and v in original format for later v computation - k_paged = k - v_paged = v - - # Debug output disabled - # print(f"attention_scores computed shape: {attention_scores.shape}") - # print(f"attention_scores sample values: {attention_scores[0, 0, :5]}") - else: - # Standard non-paged mode - k = k.to(torch.float32) - v = v.to(torch.float32) - - # get seqlens - L_q, L_k = q.shape[1], k.shape[1] - - # Compute attention scores - attention_scores = torch.matmul(q, k.transpose(-2, -1)) - if DEBUG_CORE: - print("attention_scores:", attention_scores, attention_scores.shape) - - # Scale scores - attention_scaled_scores = sm_scale * attention_scores - if DEBUG_CORE: - print("attention_scaled_scores:", attention_scaled_scores, attention_scaled_scores.shape) - - # Apply ALiBi if slopes are provided - if alibi_slopes is not None: - if cache_seqlens is not None: - # DECODE MODE: Special ALiBi handling - # In decode mode, k has shape [nheads, max_cache_len, head_dim] - # but only cache_seqlens positions are valid - - # The test's attn_bias_from_alibi_slopes uses this formula: - # relative_pos = torch.abs(row_idx + sk - sq - col_idx) - # where sk = actual valid key length, sq = query length - - row_idx = torch.arange(L_q, device=q.device, dtype=torch.float32).unsqueeze(1) - col_idx = torch.arange(L_k, device=q.device, dtype=torch.float32).unsqueeze(0) - - # Compute relative positions - # cache_seqlens is the actual number of valid keys (sk in the test) - # L_q is the query sequence length (sq in the test) - relative_pos = torch.abs(row_idx + cache_seqlens - L_q - col_idx) - - # Apply slopes - if alibi_slopes.dim() == 1: - # Shape: [nheads] -> [nheads, 1, 1] - alibi_slopes_expanded = alibi_slopes.view(-1, 1, 1) - else: - # Already has batch dimension - alibi_slopes_expanded = alibi_slopes - - alibi_bias = -alibi_slopes_expanded * relative_pos - - if DEBUG_CORE: - print(f"Decode ALiBi: cache_seqlens={cache_seqlens}, L_q={L_q}, L_k={L_k}") - print(f"relative_pos shape: {relative_pos.shape}") - print(f"alibi_bias shape: {alibi_bias.shape}") - else: - if DEBUG_CORE: - print("alibi_slopes:", alibi_slopes, alibi_slopes.shape) - alibi_bias = compute_alibi_tensor_ref(alibi_slopes, L_q, L_k) - if DEBUG_CORE: - print("alibi_bias:", alibi_bias, alibi_bias.shape) - alibi_bias = alibi_bias.reshape(-1, L_q, L_k) - if DEBUG_CORE: - print("alibi_bias_flat:", alibi_bias, alibi_bias.shape) - - attention_scaled_scores = attention_scaled_scores + alibi_bias - if DEBUG_CORE: - print("attention_scaled_scores after alibi:", attention_scaled_scores, attention_scaled_scores.shape) - - # Apply masks - row_idx = torch.arange(L_q, device=q.device).unsqueeze(1) - col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) - - if cache_seqlens is not None: - # We're in decode mode with a KV cache - # k and v are full allocated size, but only cache_seqlens positions are valid - - # Create a mask for valid cache positions - cache_mask = col_idx < cache_seqlens - - # Use cache_seqlens for offset calculation to match test's construct_local_mask - # which uses key_padding_mask.sum() as the sequence length - col_offset = cache_seqlens - L_q - - if DEBUG_CORE: - print(f"Cache mode: valid_len={cache_seqlens}, L_k={L_k}") - print(f"Using col_offset={col_offset} based on valid cache length") - else: - # Calculate offset for when seqlen_q != seqlen_k - # This offset aligns query positions to key positions - # When L_q < L_k, offset is positive, meaning query i maps to key position (i + offset) - # This is consistent with construct_local_mask in the tests which uses (sk - sq) - col_offset = L_k - L_q - cache_mask = None - - mask_applied = False - if causal and (window_size_left, window_size_right) == (-1, -1): - # Pure causal: ensure query doesn't attend to future keys - # With offset, query i can attend to keys up to position (i + col_offset) - mask = row_idx >= (col_idx - col_offset) - mask_applied = True - if DEBUG_CORE: - print("causal_mask:", mask) - elif (window_size_left, window_size_right) != (-1, -1): - # Handle the case where window sizes exceed sequence length - if window_size_left >= L_k: - window_size_left = -1 # No left limit - if window_size_right >= L_k: - window_size_right = -1 # No right limit - - if causal: - # Causal + sliding window: ensure we don't attend to future - window_size_right = min(window_size_right, 0) if window_size_right != -1 else 0 - - # Create sliding window mask - # Each query at position i attends to keys in [i + offset - left, i + offset + right] - if window_size_left == -1 and window_size_right == -1: - # No window restriction - mask = torch.ones((L_q, L_k), dtype=torch.bool, device=q.device) - else: - mask = torch.ones((L_q, L_k), dtype=torch.bool, device=q.device) - if window_size_left != -1: - # Each query at position i attends to keys from position (i - left) accounting for offset - mask = mask & (col_idx >= (row_idx + col_offset - window_size_left)) - if window_size_right != -1: - # Each query at position i attends to keys up to position (i + right) accounting for offset - mask = mask & (col_idx <= (row_idx + col_offset + window_size_right)) - - # Apply causal constraint - if causal: - causal_mask = row_idx >= (col_idx - col_offset) - mask = mask & causal_mask - - mask_applied = True - if DEBUG_CORE: - print(f"sliding_window_mask (left={window_size_left}, right={window_size_right}):", mask) - - # Apply cache mask if needed - if cache_mask is not None: - if mask_applied: - mask = mask & cache_mask - else: - mask = cache_mask - mask_applied = True - - # Apply the mask if created - if mask_applied: - attention_scaled_scores = attention_scaled_scores.masked_fill( - torch.logical_not(mask.unsqueeze(0)), float('-inf') - ) - if DEBUG_CORE: - print("attention_scaled_scores after masking:", attention_scaled_scores, attention_scaled_scores.shape) - - # Compute max for numerical stability - max_scores = torch.max(attention_scaled_scores, dim=-1, keepdim=True)[0] - if DEBUG_CORE: - print("max_scores:", max_scores, max_scores.shape) - if mask_applied: - # Replace -inf in max_scores with zeros to avoid NaN in subtraction - max_scores = torch.where( - torch.isinf(max_scores), torch.zeros_like(max_scores), max_scores - ) - if DEBUG_CORE: - print("max_scores after mask handling:", max_scores, max_scores.shape) - - # Shift scores - attention_shifted_scaled_scores = attention_scaled_scores - max_scores - if DEBUG_CORE: - print("attention_shifted_scaled_scores:", attention_shifted_scaled_scores, attention_shifted_scaled_scores.shape) - - # Exponentiate - if use_exp2: - RCP_LN = 1 / math.log(2) - exp_scores = torch.exp2(RCP_LN * attention_shifted_scaled_scores) - else: - exp_scores = torch.exp(attention_shifted_scaled_scores) - - if DEBUG_CORE: - print("exp_scores:", exp_scores, exp_scores.shape) - - # Sum of exponentials - sum_exp_scores = torch.sum(exp_scores, dim=-1, keepdim=True) - if DEBUG_CORE: - print("sum_exp_scores:", sum_exp_scores, sum_exp_scores.shape) - if mask_applied: - # if sum of exp scores is 0.0 it means scores where -inf, we cannot compute softmax and softmax_lse. Setting to 1 deals with -inf case cleanly - sum_exp_scores = torch.where( - sum_exp_scores == 0, - torch.ones_like(sum_exp_scores), - sum_exp_scores - ) - if DEBUG_CORE: - print("sum_exp_scores:", sum_exp_scores, sum_exp_scores.shape) - - # Compute softmax probabilities - p = exp_scores / sum_exp_scores - - if DEBUG_CORE: - print("softmax:", p, p.shape) - - # apply dropout if specified - if dropout_p > 0.0: - rand_vals = torch.rand(p.shape, generator=torch.Generator(device=p.device).manual_seed(philox_seed), device=p.device, dtype=p.dtype) - dropout_mask, dropout_scale = rand_vals > dropout_p, (1.0 / (1 - dropout_p)) - if DEBUG_CORE: - print("dropout_scale:", dropout_scale) - print("dropout_mask:", dropout_mask) - # Apply dropout mask and scale - # Set -1 for dropped positions and 1 for kept positions in exp_scores - sd_mask = torch.where(dropout_mask, exp_scores, -exp_scores) - p = torch.where(dropout_mask, p , torch.zeros_like(p)) * dropout_scale - if DEBUG_CORE: - print("softmax after dropout:", p) - print("sd_mask:", sd_mask) - else: - sd_mask = exp_scores - - # Compute log-sum-exp - if use_exp2: - LN2 = math.log(2) - RCP_LN = 1 / math.log(2) - max_scores_base2 = max_scores * RCP_LN - softmax_lse_base2 = max_scores_base2 + torch.log2(sum_exp_scores) - softmax_lse = softmax_lse_base2 * LN2 - softmax_lse.squeeze_(-1) - else: - softmax_lse = max_scores + torch.log(sum_exp_scores) - softmax_lse = softmax_lse.squeeze(-1) - - if DEBUG_CORE: - print("softmax_lse:", softmax_lse, softmax_lse.shape) - - # Compute output - if is_paged: - # Compute output on-the-fly using paged v cache - nheads_q = p.shape[0] - L_q = p.shape[1] - nheads_v = v_paged.shape[2] # [num_blocks, block_size, nheads_v, head_dim] - head_dim = v_paged.shape[3] - - # Handle MQA/GQA for v - assert nheads_q % nheads_v == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_v ({nheads_v})" - v_group_size = nheads_q // nheads_v - - o = torch.zeros((nheads_q, L_q, head_dim), dtype=torch.float32, device=p.device) - - # Accumulate weighted v values - for kv_pos in range(L_k): - # Calculate which block and position within block - block_idx = kv_pos // paged_kv_block_size - within_block_idx = kv_pos % paged_kv_block_size - - # Get the physical block number from block_table - if block_table.dim() == 2: - physical_block = block_table[0, block_idx].item() - else: - physical_block = block_table[block_idx].item() - - # Access v values directly from paged cache - # v_paged shape: [num_blocks, block_size, nheads_v, head_dim] - v_vec = v_paged[physical_block, within_block_idx, :, :].to(torch.float32) # [nheads_v, head_dim] - - # For GQA/MQA, we need to repeat v_vec for each group - if v_group_size > 1: - # Expand v_vec to match query heads - # v_vec: [nheads_v, head_dim] -> [nheads_q, head_dim] - v_vec = v_vec.repeat_interleave(v_group_size, dim=0) - - # Weight by attention probabilities - # p is [nheads_q, L_q, L_k], we need p[:, :, kv_pos] which is [nheads_q, L_q] - # v_vec is [nheads_q, head_dim] - # We want to add p[:, :, kv_pos] * v_vec to each query position - weights = p[:, :, kv_pos].unsqueeze(-1) # [nheads_q, L_q, 1] - o += weights * v_vec.unsqueeze(1) # [nheads_q, L_q, head_dim] - else: - o = torch.matmul(p, v) - - # Debug output disabled - # if False: - # print(f"Output o shape: {o.shape}") - # print(f"Output o sample values: {o[0, 0, :5]}") - - if DEBUG_CORE: - print("o:", o, o.shape) - - # cast back to original dtype - o = o.to(torch.float16) - # softmax_lse = softmax_lse.to(torch.float16) # NOTE: if you cast lse to fp16 it cause accuracy issues. keep fp32 - sd_mask = sd_mask.to(torch.float16) - - return o, softmax_lse, sd_mask - -def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, window_size_left, window_size_right, layout, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2): - """Compute reference output and softmax_lse using PyTorch's built-in function""" - - # Ensure the layout is 'bhsd' - if layout == "bshd": - q = q.transpose(1, 2).contiguous() - k = k.transpose(1, 2).contiguous() - v = v.transpose(1, 2).contiguous() - elif layout != "bhsd": - raise ValueError(f"Unknown layout {layout}") - - # Prepare tensors - batch_size, nheads_q, seq_len_q, head_dim = q.shape - batch_size, nheads_k, seq_len_k, head_dim = k.shape - group_size = nheads_q // nheads_k - if nheads_q % nheads_k != 0: - raise ValueError("nheads_q must be divisible by nheads_k") - - if group_size != 1: - # MQA or GQA case - # Reshape q to [batch_size, nheads_k, group_size, seq_len_q, head_dim] - q = q.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) - # Expand k and v to match group_size - k = k.unsqueeze(2).expand(-1, -1, group_size, -1, -1) - v = v.unsqueeze(2).expand(-1, -1, group_size, -1, -1) - # Flatten the first three dimensions for computation - q = q.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) - k = k.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) - v = v.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) - else: - q = q.reshape(batch_size * nheads_q, seq_len_q, head_dim) - k = k.reshape(batch_size * nheads_k, seq_len_k, head_dim) - v = v.reshape(batch_size * nheads_k, seq_len_k, head_dim) - - # Call the core attention function - o, softmax_lse, sd_mask = attention_forward_core_ref_impl( - q, k, v, sm_scale, causal, window_size_left, window_size_right, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2 - ) - - if group_size != 1: - # Reshape outputs back to original dimensions - o = o.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) - o = o.reshape(batch_size, nheads_q, seq_len_q, head_dim) - softmax_lse = softmax_lse.reshape(batch_size, nheads_k, group_size, seq_len_q) - softmax_lse = softmax_lse.reshape(batch_size, nheads_q, seq_len_q) - sd_mask = sd_mask.reshape(batch_size, nheads_k, group_size, seq_len_q, seq_len_k) - sd_mask = sd_mask.reshape(batch_size, nheads_q, seq_len_q, seq_len_k) - else: - # Standard case - o = o.reshape(batch_size, nheads_q, seq_len_q, head_dim) - softmax_lse = softmax_lse.reshape(batch_size, nheads_q, seq_len_q) - sd_mask = sd_mask.reshape(batch_size, nheads_q, seq_len_q, seq_len_k) - - # Restore original layout if necessary - if layout == "bshd": - o = o.transpose(1, 2) - - return o, softmax_lse, sd_mask - - -def attention_varlen_forward_pytorch_ref_impl( - q, - k, - v, - sm_scale, - causal, - window_size_left, - window_size_right, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2 -): - # Ensure the layout is 'thd' - if layout != 'thd': - raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.") - - batch_size = cu_seqlens_q.shape[0] - 1 - nheads_q, nheads_k = q.shape[1], k.shape[1] - head_dim = q.shape[2] - - # Pre-allocate outputs - total_L_q = q.shape[0] - total_L_k = k.shape[0] - - o = torch.zeros((total_L_q, nheads_q, head_dim), dtype=q.dtype, device=q.device) - softmax_lse = torch.zeros((nheads_q, total_L_q), dtype=torch.float32, device=q.device) - sd_mask = torch.zeros((batch_size, nheads_q, max_seqlen_q, max_seqlen_k), dtype=torch.float32, device=q.device) - - # Compute group_size for MQA/GQA handling - group_size = nheads_q // nheads_k - if nheads_q % nheads_k != 0: - raise ValueError("nheads_q must be divisible by nheads_k") - - for i in range(batch_size): - # Get the start and end indices for the current sequence - start_q = cu_seqlens_q[i].item() - end_q = cu_seqlens_q[i + 1].item() - start_k = cu_seqlens_k[i].item() - end_k = cu_seqlens_k[i + 1].item() - - seqlen_q = end_q - start_q - seqlen_k = end_k - start_k - - if DEBUG: - print(f"Batch {i} with seqlen_q = {seqlen_q}, seqlen_k = {seqlen_k}, Hq= {nheads_q}, Hk = {nheads_k}") - - # Extract q_i, k_i, v_i - q_i = q[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] - k_i = k[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] - v_i = v[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] - - # Permute to [nheads, L_q_i, head_dim] - q_i = q_i.permute(1, 0, 2) - k_i = k_i.permute(1, 0, 2) - v_i = v_i.permute(1, 0, 2) - - # Handle MQA/GQA by adjusting shapes based on group_size - if group_size != 1: - # Reshape q_i to [nheads_k, group_size, L_q_i, head_dim] - q_i = q_i.reshape(nheads_k, group_size, seqlen_q, head_dim) - # Expand k_i and v_i to match group_size - k_i = k_i.unsqueeze(1).expand(-1, group_size, -1, -1) - v_i = v_i.unsqueeze(1).expand(-1, group_size, -1, -1) - # Flatten the first two dimensions for computation - q_i = q_i.reshape(nheads_k * group_size, seqlen_q, head_dim) - k_i = k_i.reshape(nheads_k * group_size, seqlen_k, head_dim) - v_i = v_i.reshape(nheads_k * group_size, seqlen_k, head_dim) - else: - # Standard case - q_i = q_i.reshape(nheads_q, seqlen_q, head_dim) - k_i = k_i.reshape(nheads_k, seqlen_k, head_dim) - v_i = v_i.reshape(nheads_k, seqlen_k, head_dim) - - if alibi_slopes is not None: - alibi_slopes_i = alibi_slopes[i] - else: - alibi_slopes_i = None - - # Call the core attention function for this sequence - o_i, softmax_lse_i, sd_mask_i = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, window_size_left, window_size_right, dropout_p, philox_seed, philox_offset, alibi_slopes_i, use_exp2) - - # Reshape outputs back to original dimensions - if group_size != 1: - # Reshape outputs to [nheads_k, group_size, seqlen_q, head_dim] - o_i = o_i.reshape(nheads_k, group_size, seqlen_q, head_dim) - # Combine the first two dimensions back to nheads_q - o_i = o_i.reshape(nheads_q, seqlen_q, head_dim) - # Reshape softmax_lse_i similarly - softmax_lse_i = softmax_lse_i.reshape(nheads_k, group_size, seqlen_q) - softmax_lse_i = softmax_lse_i.reshape(nheads_q, seqlen_q) - else: - # Outputs are already in the correct shape - pass - - # Convert back to 'thd' layout - o_i = o_i.permute(1, 0, 2) # [L_q_i, nheads_q, head_dim] - sd_mask_i = sd_mask_i # [nheads_q, L_q_i, L_k_i] - - # Place outputs in pre-allocated tensors - o[start_q:end_q, :, :] = o_i - softmax_lse[:, start_q:end_q] = softmax_lse_i - sd_mask[i, :, :seqlen_q, :seqlen_k] = sd_mask_i - - return o, softmax_lse, sd_mask - -def attention_prefill_forward_ref_impl( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - window_size_left: int, - window_size_right: int, - layout: Literal["bshd", "bhsd", "thd"], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - philox_seed: Optional[int], - philox_offset: Optional[int], - use_exp2: bool -): - # compute reference - if layout == "thd": - o_ref, softmax_lse_ref, sd_mask_ref = attention_varlen_forward_pytorch_ref_impl( - q.clone(), - k.clone(), - v.clone(), - sm_scale, - causal, - window_size_left, - window_size_right, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2, - ) - else: - o_ref, softmax_lse_ref, sd_mask_ref = attention_vanilla_forward_pytorch_ref_impl( - q.clone(), - k.clone(), - v.clone(), - sm_scale, - causal, - window_size_left, - window_size_right, - layout, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - use_exp2) - - # copy back to ouput tensor - out.copy_(o_ref.to(out.dtype)) - - return softmax_lse_ref, sd_mask_ref - -def attention_decode_forward_ref_impl( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - k_new: Optional[torch.Tensor], - v_new: Optional[torch.Tensor], - out: torch.Tensor, - sm_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - alibi_slopes: Optional[torch.Tensor], - layout: Literal["bshd"], - cache_seqlens: Optional[torch.Tensor], - cache_batch_idx: Optional[torch.Tensor], - block_table: Optional[torch.Tensor] = None, - q_descale: Optional[torch.Tensor] = None, - k_descale: Optional[torch.Tensor] = None, - v_descale: Optional[torch.Tensor] = None, -): - """Compute reference output for decode attention using PyTorch's built-in functions""" - - if False: # Permanently disabled old debug output - pass - # print("\n========== attention_decode_forward_ref_impl inputs ==========") - # print(f"q shape: {q.shape}, dtype: {q.dtype}, device: {q.device}") - print(f"q values:\n{q}") - print(f"\nk_cache shape: {k_cache.shape}, dtype: {k_cache.dtype}, device: {k_cache.device}") - print(f"k_cache values:\n{k_cache}") - print(f"\nv_cache shape: {v_cache.shape}, dtype: {v_cache.dtype}, device: {v_cache.device}") - print(f"v_cache values:\n{v_cache}") - print(f"\nk_new: {k_new.shape if k_new is not None else None}, dtype: {k_new.dtype if k_new is not None else None}") - if k_new is not None: - print(f"k_new values:\n{k_new}") - print(f"\nv_new: {v_new.shape if v_new is not None else None}, dtype: {v_new.dtype if v_new is not None else None}") - if v_new is not None: - print(f"v_new values:\n{v_new}") - print(f"\nout shape: {out.shape}, dtype: {out.dtype}, device: {out.device}") - print(f"out values:\n{out}") - print(f"\nsm_scale: {sm_scale}") - print(f"causal: {causal}") - print(f"window_size_left: {window_size_left}") - print(f"window_size_right: {window_size_right}") - print(f"\nalibi_slopes: {alibi_slopes.shape if alibi_slopes is not None else None}, dtype: {alibi_slopes.dtype if alibi_slopes is not None else None}") - if alibi_slopes is not None: - print(f"alibi_slopes values:\n{alibi_slopes}") - print(f"\nlayout: {layout}") - print(f"cache_seqlens: {cache_seqlens}") - if cache_seqlens is not None and torch.is_tensor(cache_seqlens): - print(f"cache_seqlens values: {cache_seqlens}") - print(f"cache_batch_idx: {cache_batch_idx}") - if cache_batch_idx is not None: - print(f"cache_batch_idx values: {cache_batch_idx}") - print(f"\nblock_table: {block_table.shape if block_table is not None else None}, dtype: {block_table.dtype if block_table is not None else None}") - if block_table is not None: - print(f"block_table values:\n{block_table}") - print("=" * 60) - - # get batch size before any layout conversion - batch_size = q.shape[0] - - # Determine if we're in paged KV mode - is_paged = block_table is not None - if is_paged: - # Infer block size from cache shape - # k_cache shape for paged: [num_blocks, block_size, nheads, head_dim] - paged_kv_block_size = k_cache.shape[1] - else: - paged_kv_block_size = None - - # handle cache_batch_idx - if cache_batch_idx is not None: - # remap batch indices for cache access - batch_indices = cache_batch_idx - else: - batch_indices = torch.arange(batch_size, device=q.device) - - # copy new keys and values into cache if provided (before any layout conversion) - if k_new is not None and v_new is not None: - if is_paged: - # For paged KV cache, we need to update the blocks with new k/v values - _, seq_len_new, _, _ = k_new.shape # shape is [batch, seq_len, nheads, head_dim] for bshd layout - - for b in range(batch_size): - # Determine where to place new k/v in cache - if cache_seqlens is not None: - if torch.is_tensor(cache_seqlens): - start_pos = cache_seqlens[b].item() - else: - start_pos = cache_seqlens - else: - start_pos = 0 - - # For each new position, find the corresponding block and update it - for pos_offset in range(seq_len_new): - kv_pos = start_pos + pos_offset - - # Calculate which block and position within block - block_idx = kv_pos // paged_kv_block_size - within_block_idx = kv_pos % paged_kv_block_size - - # Get the physical block number from block_table - physical_block = block_table[b, block_idx].item() - - # Update the k and v values in the paged cache - # k_cache shape: [num_blocks, block_size, nheads, head_dim] - # k_new shape: [batch, seq_len, nheads, head_dim] - k_cache[physical_block, within_block_idx, :, :] = k_new[b, pos_offset, :, :] - v_cache[physical_block, within_block_idx, :, :] = v_new[b, pos_offset, :, :] - else: - _, seq_len_new, _, _ = k_new.shape # shape is [batch, seq_len, nheads, head_dim] for bshd layout - - for b in range(batch_size): - cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices - - # determine where to place new k/v in cache - if cache_seqlens is not None: - if torch.is_tensor(cache_seqlens): - start_pos = cache_seqlens[b].item() - else: - start_pos = cache_seqlens - else: - # if no cache_seqlens, assume we're filling from the beginning - start_pos = 0 - - end_pos = start_pos + seq_len_new - - # copy new keys and values into cache (both are in bshd layout) - k_cache[cache_idx, start_pos:end_pos, :, :] = k_new[b, :, :, :] - v_cache[cache_idx, start_pos:end_pos, :, :] = v_new[b, :, :, :] - - # ensure the layout is 'bhsd' - if layout == "bshd": - q = q.transpose(1, 2).contiguous() - if not is_paged: - k_cache = k_cache.transpose(1, 2).contiguous() - v_cache = v_cache.transpose(1, 2).contiguous() - elif layout != "bhsd": - raise ValueError(f"Unknown layout {layout}") - - # prepare tensors - batch_size_q, nheads_q, seq_len_q, head_dim = q.shape - - if is_paged: - # For paged cache: [num_blocks, block_size, nheads, head_dim] - num_blocks, block_size, nheads_k, head_dim_k = k_cache.shape - _, _, nheads_v, head_dim_v = v_cache.shape - max_cache_len = None # Not directly available in paged mode - batch_size_cache = None # Not applicable in paged mode - else: - batch_size_cache, nheads_k, max_cache_len, head_dim_k = k_cache.shape - _, nheads_v, _, head_dim_v = v_cache.shape - - # validate dimensions - assert head_dim == head_dim_k == head_dim_v, f"Head dimensions must match: {head_dim}, {head_dim_k}, {head_dim_v}" - - # handle MQA/GQA - group_size = nheads_q // nheads_k - if nheads_q % nheads_k != 0: - raise ValueError("nheads_q must be divisible by nheads_k") - - # handle cache_batch_idx - if cache_batch_idx is not None: - # remap batch indices for cache access - batch_indices = cache_batch_idx - else: - batch_indices = torch.arange(batch_size, device=q.device) - - # prepare outputs - o = torch.zeros_like(q) - softmax_lse = torch.zeros((batch_size, nheads_q, seq_len_q), dtype=torch.float32, device=q.device) - - # process each batch element - for b in range(batch_size): - if not is_paged: - cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices - - # determine valid cache length for this batch element - if cache_seqlens is not None: - if torch.is_tensor(cache_seqlens): - cache_len = cache_seqlens[b].item() - if k_new is not None: - _, seq_len_new, _, _ = k_new.shape - cache_len += seq_len_new - else: - cache_len = cache_seqlens - if k_new is not None: - _, seq_len_new, _, _ = k_new.shape - cache_len += seq_len_new - else: - if is_paged: - # For paged mode, we need cache_seqlens to know the valid length - raise ValueError("cache_seqlens must be provided for paged KV cache") - else: - cache_len = max_cache_len - - if is_paged: - # For paged KV cache, pass the cache and block table directly - # Extract block table for this batch element - block_table_b = block_table[b:b+1, :] # [1, num_blocks] - k_b = k_cache # Pass entire paged cache - v_b = v_cache # Pass entire paged cache - q_b = q[b:b+1, :, :, :] # [1, nheads_q, seq_len_q, head_dim] - - # For paged mode with MQA/GQA, we handle expansion in the core function - # Just reshape q for now - q_b = q_b.reshape(nheads_q, seq_len_q, head_dim) - else: - # Standard non-paged mode - k_b = k_cache[cache_idx, :, :, :] # [nheads_k, max_cache_len, head_dim] - v_b = v_cache[cache_idx, :, :, :] # [nheads_v, max_cache_len, head_dim] - q_b = q[b:b+1, :, :, :] # [1, nheads_q, seq_len_q, head_dim] - block_table_b = None - - # handle MQA/GQA by expanding k and v - if group_size != 1: - # expand k and v to match q's number of heads - k_b = k_b.unsqueeze(1).expand(-1, group_size, -1, -1) - k_b = k_b.reshape(nheads_q, max_cache_len, head_dim) - - v_b = v_b.unsqueeze(1).expand(-1, group_size, -1, -1) - v_b = v_b.reshape(nheads_q, max_cache_len, head_dim) - - # reshape for attention_forward_core_ref_impl - q_b = q_b.reshape(nheads_q, seq_len_q, head_dim) - - # handle alibi slopes for this batch - alibi_slopes_b = None - if alibi_slopes is not None: - if alibi_slopes.dim() == 2: - alibi_slopes_b = alibi_slopes[b] - else: - alibi_slopes_b = alibi_slopes - - # call core attention function with cache information - o_b, softmax_lse_b, _ = attention_forward_core_ref_impl( - q_b, k_b, v_b, sm_scale, causal, window_size_left, window_size_right, - dropout_p=0.0, philox_seed=None, philox_offset=None, - alibi_slopes=alibi_slopes_b, use_exp2=True, - cache_seqlens=cache_len, # Pass valid cache length - block_table=block_table_b, # Pass block table for paged mode - paged_kv_block_size=paged_kv_block_size, # Pass block size for paged mode - ) - - # store outputs - o[b, :, :, :] = o_b.reshape(nheads_q, seq_len_q, head_dim) - softmax_lse[b, :, :] = softmax_lse_b.reshape(nheads_q, seq_len_q) - - # restore original layout if necessary - if layout == "bshd": - o = o.transpose(1, 2) - - # copy output to the provided tensor - out.copy_(o.to(out.dtype)) - - return softmax_lse \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py deleted file mode 100644 index 3dc443abf67..00000000000 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ /dev/null @@ -1,927 +0,0 @@ -import torch -import os -from .fwd_prefill import attention_prefill_forward_triton_impl -from .bwd_prefill_split import attention_prefill_backward_triton_split_impl -from .bwd_prefill_fused_atomics import attention_prefill_backward_triton_fused_atomics_impl -from .bwd_prefill_fused_no_atomics import attention_prefill_backward_triton_split_fused_no_atomics_impl -from .fwd_decode import attention_decode_forward_triton_impl -from .fwd_ref import attention_prefill_forward_ref_impl, attention_decode_forward_ref_impl -from .bwd_ref import attention_backward_pytorch_ref_impl -from .utils import DEBUG, USE_REF, MetaData, is_fp8 -from einops import rearrange, repeat -from flash_attn.layers.rotary import apply_rotary_emb -from typing import Literal, Optional, Union - - -USE_EXP2 = True -BWD_MODE = os.environ.get('BWD_MODE', 'fused_no_atomics').lower() - -def fwd(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: Optional[torch.Tensor], - alibi_slopes: Optional[torch.Tensor], - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - return_softmax: bool, - gen_: Optional[torch.Tensor] = None, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None - ): - - if DEBUG: - print() - print("flash_attn_triton_amd.py::fwd inputs") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("out:", out, out.shape if out is not None else None) - print("alibi_slopes:", alibi_slopes) - print("dropout_p:", dropout_p) - print("softmax_scale:", softmax_scale) - print("causal:", causal) - print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("softcap:", softcap) - print("return_softmax:", return_softmax) - print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) - print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) - print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) - - if is_fp8(q): - assert out is not None, "fp8 output tensor should be passed in." - if (descale_q is None) or (descale_k is None) or (descale_v is None): - import warnings - warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) - else: - out = torch.zeros_like(q) if out is None else out.zero_() - - # Setup metadata - metadata = MetaData(sm_scale=softmax_scale) - metadata.max_seqlens_q = q.shape[1] - metadata.max_seqlens_k = k.shape[1] - metadata.layout = "bshd" - - # get shape - batch, _ , nheads_q, _= q.shape - - if causal: - metadata.need_causal(True) - - if alibi_slopes is not None: - if alibi_slopes.dim() == 2: - pass - elif alibi_slopes.dim() == 1: - alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) - else: - raise ValueError(f"Alibi can be (nheads,) or (batch_size, nheads). Given tensor with shape {alibi_slopes.shape}") - metadata.need_alibi(alibi_slopes, batch, nheads_q) - - # store rng state - metadata.need_dropout(dropout_p, return_softmax) - rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensors uses the underlying data and doesnot cast - - # check arguments - metadata.check_args(q, k, v, out) - - # call implementation - if USE_REF: - if DEBUG: - print("Using reference implementation") - softmax_lse_ref, sd_mask_ref = attention_prefill_forward_ref_impl( - q, - k, - v, - out, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - window_size_left, - window_size_right, - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - USE_EXP2) - softmax_lse=softmax_lse_ref - sd_mask=sd_mask_ref - else: - if DEBUG: - print("Using Triton implementation") - softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( - q, - k, - v, - out, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - window_size_left, - window_size_right, - None, - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - metadata.return_softmax, - USE_EXP2, - descale_q, - descale_k, - descale_v) - softmax_lse=softmax_lse_triton - sd_mask=sd_mask_triton - - if DEBUG: - print("flash_attn_triton_amd.py::fwd outputs") - print("o:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None ) - print("rng_state:", rng_state) - - return out, softmax_lse, sd_mask, rng_state - -def bwd( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - alibi_slopes: Optional[torch.Tensor], - dropout_p: float, - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - deterministic: bool, - gen_: Optional[torch.Tensor] = None, - rng_state:Optional[torch.Tensor] = None, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_o: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None, - descale_dq: Optional[torch.Tensor] = None, - descale_dk: Optional[torch.Tensor] = None, - descale_dv: Optional[torch.Tensor] = None, -): - if DEBUG: - print() - print("flash_attn_triton_amd.py::bwd inputs") - print("dout:", dout, dout.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("out:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape if dq is not None else None) - print("dk:", dk, dk.shape if dk is not None else None) - print("dv:", dv, dv.shape if dv is not None else None) - print("alibi_slopes:", alibi_slopes) - print("dropout_p:", dropout_p) - print("out:", out) - print("softmax_scale:", softmax_scale) - print("causal:", causal) - print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("deterministic:", deterministic) - print("gen_:", gen_) - print("rng_state:", rng_state) - print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) - print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) - print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) - print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) - print("descale_do:", descale_do, descale_do.shape if descale_do is not None else None) - print("descale_dq:", descale_dq, descale_dq.shape if descale_dq is not None else None) - print("descale_dk:", descale_dk, descale_dk.shape if descale_dk is not None else None) - print("descale_dv:", descale_dv, descale_dv.shape if descale_dv is not None else None) - - dq = torch.zeros_like(q) if dq is None else dq.zero_() - dk = torch.zeros_like(k) if dk is None else dk.zero_() - dv = torch.zeros_like(v) if dv is None else dv.zero_() - - # get shape - batch, _ , nheads_q, _= q.shape - - if dropout_p > 0.0: - assert rng_state is not None - philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() - else: - philox_seed, philox_offset = None, None - - if alibi_slopes is not None: - if alibi_slopes.dim() == 2: - pass - elif alibi_slopes.dim() == 1: - alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) - else: - raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") - - # call implementation - if USE_REF: - if DEBUG: - print("Using reference implementation") - - delta_ref = attention_backward_pytorch_ref_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - window_size_left, - window_size_right, - "bshd", - None, - None, - None, - None, - dropout_p, - philox_seed, - philox_offset, - USE_EXP2, - ) - delta = delta_ref - else: - if DEBUG: - print("Using Triton implementation") - if BWD_MODE == "split": - delta_triton = attention_prefill_backward_triton_split_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - "bshd", - None, - None, - None, - None, - dropout_p, - philox_seed, - philox_offset, - USE_EXP2, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - delta = delta_triton - elif BWD_MODE == "fused_atomics": - delta_triton = attention_prefill_backward_triton_fused_atomics_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - None, - None, - q.shape[1], - k.shape[1], - dropout_p, - philox_seed, - philox_offset, - descale_q, - descale_k, - descale_v, - descale_o, - True, - ) - delta = delta_triton - elif BWD_MODE == "fused_no_atomics": - delta_triton = attention_prefill_backward_triton_split_fused_no_atomics_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - "bshd", - None, - None, - None, - None, - dropout_p, - philox_seed, - philox_offset, - USE_EXP2, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - delta = delta_triton - else: - raise ValueError(f"Unknown bwd mode {BWD_MODE}") - - if DEBUG: - print("flash_attn_triton_amd.py::bwd outputs") - print("dv:", dv, dv.shape) - if is_fp8(dv): - print("descale_dv:", descale_dv, descale_dv.shape if descale_dv is not None else None) - print("dk:", dk, dk.shape) - if is_fp8(dk): - print("descale_dk:", descale_dk, descale_dk.shape if descale_dk is not None else None) - print("dq:", dq, dq.shape) - if is_fp8(dq): - print("descale_dq:", descale_dq, descale_dq.shape if descale_dq is not None else None) - return dq, dk, dv, delta - -def varlen_fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - seqused_k: Optional[torch.Tensor], - leftpad_k: Optional[torch.Tensor], - block_table_: Optional[torch.Tensor], - alibi_slopes: Optional[torch.Tensor], - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - zero_tensors: bool , - causal: bool , - window_size_left: int, - window_size_right: int, - softcap: float, - return_softmax: bool, - gen_: Optional[torch.Tensor] = None, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None - ): - - if DEBUG: - print() - print("flash_attn_triton_amd.py::varlen_fwd") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) - print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) - print("alibi_slopes:", alibi_slopes) - print("max_seqlen_q:", max_seqlen_q) - print("max_seqlen_k:", max_seqlen_k) - print("dropout_p:", dropout_p) - print("softmax_scale:", softmax_scale) - print("causal:", causal) - print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("gen_:", gen_) - print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) - print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) - print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) - - if is_fp8(q): - assert out is not None, "fp8 output tensor should be passed in." - if (descale_q is None) or (descale_k is None) or (descale_v is None): - import warnings - warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) - else: - out = torch.zeros_like(q) if out is None else out.zero_() - - # Setup metadata - metadata = MetaData(sm_scale=softmax_scale) - metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # set layout to "thd" and other metdata - assert metadata.layout is not None - - # get shape - batch = len(cu_seqlens_q) - 1 - _, nheads_q, _= q.shape - - if causal: - metadata.need_causal(True) - - if alibi_slopes is not None: - if alibi_slopes.dim() == 2: - pass - elif alibi_slopes.dim() == 1: - alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) - else: - raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") - metadata.need_alibi(alibi_slopes, batch, nheads_q) - - # store rng state - metadata.need_dropout(dropout_p, return_softmax) - rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensors uses the underlying data and doesnot cast - - # Check arguments - metadata.check_args(q, k, v, out) - - # call implementation - if USE_REF: - if DEBUG: - print("Using reference implementation") - softmax_lse_ref, sd_mask_ref = attention_prefill_forward_ref_impl( - q, - k, - v, - out, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - window_size_left, - window_size_right, - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - USE_EXP2) - softmax_lse=softmax_lse_ref - sd_mask=sd_mask_ref - else: - if DEBUG: - print("Using Triton implementation") - softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( - q, - k, - v, - out, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - window_size_left, - window_size_right, - None, - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - metadata.return_softmax, - USE_EXP2, - descale_q, - descale_k, - descale_v) - softmax_lse=softmax_lse_triton - sd_mask=sd_mask_triton - - if DEBUG: - print("varlen_fwd outputs") - print("out:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None ) - - - return out, softmax_lse, sd_mask, rng_state - -def varlen_bwd( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - alibi_slopes: Optional[torch.Tensor], - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - softmax_scale: float, - zero_tensors: bool, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - deterministic: bool, - gen_ : Optional[torch.Tensor] = None, - rng_state: Optional[torch.Tensor] = None, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_o: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None, - descale_dq: Optional[torch.Tensor] = None, - descale_dk: Optional[torch.Tensor] = None, - descale_dv: Optional[torch.Tensor] = None, -): - if DEBUG: - print() - print("varlen_bwd") - print("dout:", dout, dout.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("out:", out) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape if dq is not None else None) - print("dk:", dk, dk.shape if dk is not None else None) - print("dv:", dv, dv.shape if dv is not None else None) - print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) - print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) - print("alibi_slopes:", alibi_slopes) - print("max_seqlen_q:", max_seqlen_q) - print("max_seqlen_k:", max_seqlen_k) - print("dropout_p:", dropout_p) - print("softmax_scale:", softmax_scale) - print("causal:", causal) - print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("deterministic:", deterministic) - print("gen_:", gen_) - print("rng_state:", rng_state) - print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) - print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) - print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) - print("descale_do:", descale_do, descale_do.shape if descale_do else None) - - dq = torch.zeros_like(q) if dq is None else dq.zero_() - dk = torch.zeros_like(k) if dk is None else dk.zero_() - dv = torch.zeros_like(v) if dv is None else dv.zero_() - - # get shape - batch = len(cu_seqlens_q) - 1 - _, nheads_q, _= q.shape - - if dropout_p > 0.0: - assert rng_state is not None - philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() - else: - philox_seed, philox_offset = None, None - - if alibi_slopes is not None: - if alibi_slopes.dim() == 2: - pass - elif alibi_slopes.dim() == 1: - alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) - else: - raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") - - # call implementation - if USE_REF: - if DEBUG: - print("Using reference implementation") - delta_ref = attention_backward_pytorch_ref_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - window_size_left, - window_size_right, - "thd", - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - USE_EXP2, - ) - delta = delta_ref - else: - if DEBUG: - print("Using Triton implementation") - if BWD_MODE == "split": - delta_triton = attention_prefill_backward_triton_split_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - "thd", - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - USE_EXP2, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - delta = delta_triton - elif BWD_MODE == "fused_atomics": - delta_triton = attention_prefill_backward_triton_fused_atomics_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - descale_q, - descale_k, - descale_v, - descale_o, - True, - ) - delta = delta_triton - elif BWD_MODE == "fused_no_atomics": - delta_triton = attention_prefill_backward_triton_split_fused_no_atomics_impl( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - softmax_scale, - alibi_slopes, - causal, - "thd", - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - USE_EXP2, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - delta = delta_triton - else: - raise ValueError(f"Unknown bwd mode {BWD_MODE}") - - if DEBUG: - print("varlen_bwd outputs") - print("delta:", delta, delta.shape) - print("dv:", dv, dv.shape) - print("dk:", dk, dk.shape) - print("dq:", dq, dq.shape) - - return dq, dk, dv, delta - -def fwd_kvcache( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - k: Optional[torch.Tensor], - v: Optional[torch.Tensor], - cache_seqlens: Optional[Union[(int, torch.Tensor)]], - rotary_cos: Optional[torch.Tensor], - rotary_sin: Optional[torch.Tensor], - cache_batch_idx: Optional[torch.Tensor], - cache_leftpad: Optional[torch.Tensor], - block_table: Optional[torch.Tensor], - alibi_slopes: Optional[torch.Tensor], - out: Optional[torch.Tensor], - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - rotary_interleaved: bool, - num_splits: int - ): - - if DEBUG: - print() - print("flash_attn_triton_amd.py::fwd_kvcache inputs") - print("q:", q, q.shape) - print("k_cache:", k_cache, k_cache.shape) - print("v_cache:", v_cache, v_cache.shape) - print("k:", k, k.shape if k is not None else None) - print("v:", v, v.shape if v is not None else None) - print("cache_seqlens:", cache_seqlens ) - print("rotary_cos:",rotary_cos ) - print("rotary_sin:",rotary_sin) - print("cache_batch_idx:", cache_batch_idx) - print("cache_leftpad:", cache_leftpad) - print("block_table:", block_table) - print("alibi_slopes:", alibi_slopes) - print("out:", out) - print("softmax_scale:", softmax_scale) - print("causal:", causal) - print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("softcap:", softcap) - print("rotary_interleaved:", rotary_interleaved) - print("num_splits:", num_splits) - - # output - out = torch.zeros_like(q) if out is None else out.zero_() - - # fill metadata - metadata = MetaData(sm_scale=softmax_scale) - metadata.layout = "bshd" - metadata.max_seqlens_q = q.shape[1] - metadata.max_seqlens_k = k_cache.shape[1] - metadata.cache_batch_idx = cache_batch_idx - if isinstance(cache_seqlens, int): - metadata.cache_seqlens = torch.tensor(cache_seqlens, device=q.device) - else: - metadata.cache_seqlens = cache_seqlens - - # window_size can be a tensor sometimes - if isinstance(window_size_left, torch.Tensor): - metadata.window_size_left = int(window_size_left.item()) - else: - metadata.window_size_left = window_size_left - if isinstance(window_size_right, torch.Tensor): - metadata.window_size_right = int(window_size_right.item()) - else: - metadata.window_size_right = window_size_right - - k_new = k - v_new = v - - # get shape - batch, _ , nheads_q, _= q.shape - - if causal: - metadata.need_causal(True) - - if alibi_slopes is not None: - if alibi_slopes.dim() == 2: - pass - elif alibi_slopes.dim() == 1: - alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) - else: - raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") - metadata.need_alibi(alibi_slopes, batch, nheads_q) - - # rotary boolean - apply_rotary = torch.is_tensor(rotary_cos) and torch.is_tensor(rotary_sin) - if apply_rotary: - metadata.need_rotary(rotary_sin, rotary_cos, rotary_interleaved) - - # Rotary Embedding Implementation - if apply_rotary: - if metadata.causal or (window_size_left != -1 or window_size_right !=-1): # NOTE: when support is added. Add `or metadata.local` - q_ro = apply_rotary_emb( - q, - metadata.rotary_cos, - metadata.rotary_sin, - seqlen_offsets=metadata.cache_seqlens, - interleaved=metadata.rotary_interleaved, - ) - else: - q_ro = rearrange( - apply_rotary_emb( - rearrange(q, "b s h d -> b 1 (s h) d"), - metadata.rotary_cos, - metadata.rotary_sin, - seqlen_offsets=metadata.cache_seqlens, - interleaved=metadata.rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=metadata.max_seqlens_q, - ) - k_ro = apply_rotary_emb( - k_new, - metadata.rotary_cos, - metadata.rotary_sin, - seqlen_offsets=metadata.cache_seqlens, - interleaved=metadata.rotary_interleaved, - ) - - q, k_new = q_ro.to(q.dtype), k_ro.to(q.dtype) - - # launch kernel - if USE_REF: - if DEBUG: - print("Using reference implementation") - softmax_lse_ref = attention_decode_forward_ref_impl( - q, - k_cache, - v_cache, - k_new, - v_new, - out, - metadata.sm_scale, - metadata.causal, - metadata.window_size_left, - metadata.window_size_right, - metadata.alibi_slopes, - metadata.layout, - metadata.cache_seqlens, - metadata.cache_batch_idx, - block_table, - ) - softmax_lse=softmax_lse_ref - else: - if DEBUG: - print("Using Triton implementation") - softmax_lse_triton = attention_decode_forward_triton_impl( - q, - k_cache, - v_cache, - k_new, - v_new, - out, - metadata.sm_scale, - metadata.causal, - metadata.window_size_left, - metadata.window_size_right, - metadata.alibi_slopes, - metadata.layout, - metadata.cache_seqlens, - metadata.cache_batch_idx, - block_table, - ) - softmax_lse = softmax_lse_triton - - if DEBUG: - print("out:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - return out, softmax_lse diff --git a/flash_attn/flash_attn_triton_amd/interface_fa_v3.py b/flash_attn/flash_attn_triton_amd/interface_fa_v3.py deleted file mode 100755 index be8e2d3cbeb..00000000000 --- a/flash_attn/flash_attn_triton_amd/interface_fa_v3.py +++ /dev/null @@ -1,660 +0,0 @@ -import torch -import os -from .fwd_prefill import attention_prefill_forward_triton_impl -from .bwd_prefill_split import attention_prefill_backward_triton_split_impl -from .bwd_prefill_fused_atomics import attention_prefill_backward_triton_fused_atomics_impl -from .bwd_prefill_fused_no_atomics import attention_prefill_backward_triton_split_fused_no_atomics_impl -from .fwd_decode import attention_decode_forward_triton_impl -from .fwd_ref import attention_prefill_forward_ref_impl, attention_decode_forward_ref_impl -from .bwd_ref import attention_backward_pytorch_ref_impl -from .utils import DEBUG, USE_REF, MetaData, is_fp8 -from einops import rearrange, repeat -from flash_attn.layers.rotary import apply_rotary_emb -from typing import Optional, Union, Tuple - -USE_EXP2 = True -BWD_MODE = os.environ.get('BWD_MODE', 'fused_no_atomics').lower() -USE_DECODE_PATH = os.environ.get('FLASH_ATTENTION_V3_USE_DECODE', '0') == '1' - -def fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - k_new: Optional[torch.Tensor], - v_new: Optional[torch.Tensor], - qv: Optional[torch.Tensor], - out: Optional[torch.Tensor], - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - cu_seqlens_k_new: Optional[torch.Tensor], - seqused_q: Optional[torch.Tensor], - seqused_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], - page_table: Optional[torch.Tensor], - kv_batch_idx: Optional[torch.Tensor], - leftpad_k: Optional[torch.Tensor], - rotary_cos: Optional[torch.Tensor], - rotary_sin: Optional[torch.Tensor], - seqlens_rotary: Optional[torch.Tensor], - q_descale: Optional[torch.Tensor], - k_descale: Optional[torch.Tensor], - v_descale: Optional[torch.Tensor], - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - attention_chunk: int, - softcap: float, - rotary_interleaved: bool, - scheduler_metadata=None, - num_splits: int = 1, - pack_gqa=None, - sm_margin: int = 0, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Flash Attention v3 forward pass compatible interface for AMD Triton implementation. - - This function maps v3 parameters to the existing AMD Triton implementation. - """ - - if DEBUG: - print() - print("interface_fa_v3.py::fwd inputs") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("k_new:", k_new, k_new.shape if k_new is not None else None) - print("v_new:", v_new, v_new.shape if v_new is not None else None) - print("qv:", qv, qv.shape if qv is not None else None) - print("out:", out, out.shape if out is not None else None) - print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape if cu_seqlens_q is not None else None) - print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape if cu_seqlens_k is not None else None) - print("cu_seqlens_k_new:", cu_seqlens_k_new, cu_seqlens_k_new.shape if cu_seqlens_k_new is not None else None) - print("seqused_q:", seqused_q, seqused_q.shape if seqused_q is not None else None) - print("seqused_k:", seqused_k, seqused_k.shape if seqused_k is not None else None) - print("max_seqlen_q:", max_seqlen_q) - print("max_seqlen_k:", max_seqlen_k) - print("page_table:", page_table, page_table.shape if page_table is not None else None) - print("kv_batch_idx:", kv_batch_idx, kv_batch_idx.shape if kv_batch_idx is not None else None) - print("leftpad_k:", leftpad_k, leftpad_k.shape if leftpad_k is not None else None) - print("rotary_cos:", rotary_cos, rotary_cos.shape if rotary_cos is not None else None) - print("rotary_sin:", rotary_sin, rotary_sin.shape if rotary_sin is not None else None) - print("seqlens_rotary:", seqlens_rotary, seqlens_rotary.shape if seqlens_rotary is not None else None) - print("q_descale:", q_descale, q_descale.shape if q_descale is not None else None) - print("k_descale:", k_descale, k_descale.shape if k_descale is not None else None) - print("v_descale:", v_descale, v_descale.shape if v_descale is not None else None) - print("softmax_scale:", softmax_scale) - print("causal:", causal) - print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("attention_chunk:", attention_chunk) - print("softcap:", softcap) - print("rotary_interleaved:", rotary_interleaved) - print("scheduler_metadata:", scheduler_metadata) - print("num_splits:", num_splits) - print("pack_gqa:", pack_gqa) - print("sm_margin:", sm_margin) - - # Handle qv packed input - if qv is not None: - raise NotImplementedError("QV packed input is not yet supported in the AMD Triton backend") - - - # Handle softcap - if softcap != 0.0: - raise NotImplementedError(f"Softcap is not yet supported in the AMD Triton backend (got softcap={softcap}, expected 0.0)") - - # Handle attention_chunk - if attention_chunk != 0 and attention_chunk != 1: - raise NotImplementedError(f"attention_chunk is not yet supported in the AMD Triton backend (got attention_chunk={attention_chunk})") - - - # Handle scheduler metadata - if scheduler_metadata is not None: - raise NotImplementedError("Scheduler metadata is not yet supported in the AMD Triton backend") - - # Handle pack_gqa - if pack_gqa is not None and pack_gqa is not False: - raise NotImplementedError(f"pack_gqa is not yet supported in the AMD Triton backend (got pack_gqa={pack_gqa})") - - # Handle num_splits - if num_splits != 1: - raise NotImplementedError(f"Split attention (num_splits > 1) is not yet supported in the AMD Triton backend (got num_splits={num_splits})") - - # Handle sm_margin - if sm_margin != 0: - raise NotImplementedError(f"sm_margin is not yet supported in the AMD Triton backend (got sm_margin={sm_margin}, expected 0)") - - # Handle leftpad_k - if leftpad_k is not None: - raise NotImplementedError("Left padding (leftpad_k) is not yet supported in the AMD Triton backend") - - # Handle cu_seqlens_k_new - if cu_seqlens_k_new is not None: - raise NotImplementedError("cu_seqlens_k_new is not yet supported in the AMD Triton backend") - - # if seqlens_rotary is not None: - # raise NotImplementedError("seqlens_rotary is not yet supported in the AMD Triton backend") - - # Setup metadata - metadata = MetaData(sm_scale=softmax_scale) - - - # Handle variable length sequences first to determine layout - # Determine layout based on tensor dimensions and cu_seqlens presence - if cu_seqlens_q is not None: - # Q has variable length - check tensor dimensions to confirm - if len(q.shape) == 3: # [total_seqlen, nheads, head_dim] - metadata.layout = "thd" - metadata.varlen = True - metadata.cu_seqlens_q = cu_seqlens_q - metadata.max_seqlens_q = max_seqlen_q - - # K might be varlen or batch mode - if cu_seqlens_k is not None: - metadata.cu_seqlens_k = cu_seqlens_k - metadata.max_seqlens_k = max_seqlen_k - else: - # K is in batch mode while Q is varlen (KV cache scenario) - metadata.cu_seqlens_k = None - metadata.max_seqlens_k = k.shape[1] if len(k.shape) == 4 else max_seqlen_k - else: - raise ValueError(f"cu_seqlens_q provided but q has shape {q.shape}, expected 3D tensor for varlen") - else: - # Regular batch mode - metadata.layout = "bshd" - metadata.varlen = False - metadata.cu_seqlens_q = None - metadata.cu_seqlens_k = None - metadata.max_seqlens_q = q.shape[1] if max_seqlen_q is None else max_seqlen_q - metadata.max_seqlens_k = k.shape[1] if max_seqlen_k is None else max_seqlen_k - - # Now determine if we should use decode or prefill kernel - # Decode kernel should be used for KV cache scenarios where: - # 1. k_new/v_new are provided - incremental KV cache update (primary KV cache indicator) - # 2. kv_batch_idx is provided - KV cache batch indexing (primary KV cache indicator) - # 3. seqused_k without seqused_q - indicates KV cache fill levels (not varlen masking) - # Note: In varlen, both seqused_q and seqused_k are used for sequence masking - # In KV cache, only seqused_k is used to track cache fill levels - if USE_DECODE_PATH: - # Force decode path - use_decode = True - else: - # Detect KV cache scenarios: - # - Clear KV cache indicators (k_new, v_new, kv_batch_idx) - # - OR seqused_k without seqused_q (KV cache fill tracking, not varlen masking) - use_decode = ( - k_new is not None or # Have new KV to append (KV cache indicator) - v_new is not None or # Have new KV to append (KV cache indicator) - kv_batch_idx is not None or # Have KV cache batch indexing (KV cache indicator) - (seqused_k is not None and seqused_q is None) # KV cache fill levels (not varlen) - ) - - # Check for unsupported features with decode kernel - if use_decode: - if metadata.layout == "thd": - raise NotImplementedError("Varlen is not yet supported with the decode kernel in the AMD Triton backend") - if kv_batch_idx is not None: - raise NotImplementedError("kv_batch_idx is not yet supported with the decode kernel in the AMD Triton backend") - - - if out is None: - out_dtype = torch.float32 if is_fp8(q) else q.dtype - if metadata.layout == "bshd": - out = torch.zeros(q.shape[0], q.shape[1], q.shape[2], v.shape[-1], dtype=out_dtype, device=q.device) - elif metadata.layout == "thd": - out = torch.zeros(q.shape[0], q.shape[1], v.shape[-1], dtype=out_dtype, device=q.device) - else: - raise ValueError(f"Unsupported layout: {metadata.layout}. Only 'bshd' and 'thd' layouts are supported.") - else: - out = out.zero_() - - if is_fp8(q): - if (q_descale is None) or (k_descale is None) or (v_descale is None): - import warnings - warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) - - # Get shape - if metadata.layout == "bshd": - batch, _, nheads_q, _ = q.shape - else: # "thd" layout for varlen - _, nheads_q, _ = q.shape - batch = len(cu_seqlens_q) - 1 if cu_seqlens_q is not None else 1 - - # Handle causal mask - if causal: - metadata.need_causal(True) - - # Handle alibi slopes (not directly supported in v3 interface, but we'll keep the logic) - alibi_slopes = None # V3 doesn't have alibi_slopes in the signature - if alibi_slopes is not None: - if alibi_slopes.dim() == 2: - pass - elif alibi_slopes.dim() == 1: - alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) - else: - raise ValueError(f"Alibi can be (nheads,) or (batch_size, nheads). Given tensor with shape {alibi_slopes.shape}") - metadata.need_alibi(alibi_slopes, batch, nheads_q) - - # Handle dropout (v3 doesn't have dropout in forward) - dropout_p = 0.0 - return_softmax = False - metadata.need_dropout(dropout_p, return_softmax) - - # Handle rotary embeddings - if rotary_cos is not None and rotary_sin is not None: - metadata.need_rotary(rotary_sin, rotary_cos, rotary_interleaved) - - # Apply rotary embeddings if provided - if metadata.causal or window_size_left != -1 or window_size_right != -1: - q_rot = apply_rotary_emb( - q, - rotary_cos, - rotary_sin, - seqlen_offsets=seqlens_rotary, - interleaved=rotary_interleaved, - ) - q = q_rot.to(q.dtype) - - if k_new is not None: - k_rot = apply_rotary_emb( - k_new, - rotary_cos, - rotary_sin, - seqlen_offsets=seqlens_rotary, - interleaved=rotary_interleaved, - ) - k_new = k_rot.to(k.dtype) - - # Store RNG state - rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) - - # Call implementation - if USE_REF: - if DEBUG: - print("Using reference implementation") - - if use_decode: - if DEBUG: - print(f"Using decode reference implementation ( layout={metadata.layout}, cache_seqlens={seqused_k is not None}, k_new={k_new is not None}, v_new={v_new is not None}, kv_batch_idx={kv_batch_idx is not None})") - # Use decode reference implementation - softmax_lse = attention_decode_forward_ref_impl( - q, - k, # k_cache - v, # v_cache - k_new, - v_new, - out, - metadata.sm_scale, - metadata.causal, - window_size_left, - window_size_right, - metadata.alibi_slopes, - metadata.layout, - seqused_k, # cache_seqlens - kv_batch_idx, # cache_batch_idx - page_table, # block_table - q_descale, - k_descale, - v_descale, - ) - else: - if DEBUG: - print("Using prefill reference implementation") - # Use prefill reference implementation - softmax_lse_ref, sd_mask_ref = attention_prefill_forward_ref_impl( - q, k, v, out, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - window_size_left, - window_size_right, - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - USE_EXP2 - ) - softmax_lse = softmax_lse_ref - else: - if DEBUG: - print("Using Triton implementation") - - if use_decode: - if DEBUG: - print(f"Using Decode Triton implementation (cache_seqlens={seqused_k is not None}, k_new={k_new is not None}, v_new={v_new is not None}, kv_batch_idx={kv_batch_idx is not None})") - - # Use decode kernel for KV cache scenarios - # Note: seqused_k can serve as cache_seqlens in v3 - softmax_lse = attention_decode_forward_triton_impl( - q, - k, # k_cache in v2 terminology - v, # v_cache in v2 terminology - k_new, # New KV values to append to cache - v_new, # New KV values to append to cache - out, - metadata.sm_scale, - metadata.causal, - window_size_left, - window_size_right, - metadata.alibi_slopes, - metadata.layout, - seqused_k, # cache_seqlens - kv_batch_idx, # cache_batch_idx - page_table, # block_table for paged attention - q_descale, - k_descale, - v_descale, - ) - # Decode kernel returns only softmax_lse, not sd_mask - sd_mask_triton = None - else: - if DEBUG: - print("Using prefill Triton implementation") - # Use prefill kernel - softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( - q, k, v, out, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - window_size_left, - window_size_right, - None, # block_table - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.dropout_p, - metadata.philox_seed, - metadata.philox_offset, - metadata.return_softmax, - USE_EXP2, - q_descale, - k_descale, - v_descale, - seqused_q, - seqused_k, - ) - softmax_lse = softmax_lse_triton - - if DEBUG: - print("interface_fa_v3.py::fwd outputs") - print("out:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - - # Return format compatible with v3 - # V3 returns (out, softmax_lse, *rest) where rest can be empty or contain additional outputs - return out, softmax_lse - - -def bwd( - dout: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - softmax_lse: torch.Tensor, - dq: Optional[torch.Tensor], - dk: Optional[torch.Tensor], - dv: Optional[torch.Tensor], - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - seqused_q: Optional[torch.Tensor], - seqused_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], - softmax_scale: float, - causal: bool, - window_size_left: int, - window_size_right: int, - softcap: float, - deterministic: bool, - sm_margin: int = 0, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Flash Attention v3 backward pass compatible interface for AMD Triton implementation. - - This function maps v3 parameters to the existing AMD Triton implementation. - """ - - if DEBUG: - print() - print("interface_fa_v3.py::bwd inputs") - print("dout:", dout, dout.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("out:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape if dq is not None else None) - print("dk:", dk, dk.shape if dk is not None else None) - print("dv:", dv, dv.shape if dv is not None else None) - print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape if cu_seqlens_q is not None else None) - print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape if cu_seqlens_k is not None else None) - print("seqused_q:", seqused_q, seqused_q.shape if seqused_q is not None else None) - print("seqused_k:", seqused_k, seqused_k.shape if seqused_k is not None else None) - print("max_seqlen_q:", max_seqlen_q) - print("max_seqlen_k:", max_seqlen_k) - print("softmax_scale:", softmax_scale) - print("causal:", causal) - print("window_size_left:", window_size_left) - print("window_size_right:", window_size_right) - print("softcap:", softcap) - print("deterministic:", deterministic) - print("sm_margin:", sm_margin) - - # Check for unsupported features in backward pass - - # Handle softcap - if softcap != 0.0: - raise NotImplementedError(f"Softcap is not yet supported in the AMD Triton backend backward pass (got softcap={softcap}, expected 0.0)") - - # Handle sm_margin - if sm_margin != 0: - raise NotImplementedError(f"sm_margin is not yet supported in the AMD Triton backend backward pass (got sm_margin={sm_margin}, expected 0)") - - # Initialize gradient tensors if not provided - dq = torch.zeros_like(q) if dq is None else dq.zero_() - dk = torch.zeros_like(k) if dk is None else dk.zero_() - dv = torch.zeros_like(v) if dv is None else dv.zero_() - - # Determine layout based on cu_seqlens - if cu_seqlens_q is not None and cu_seqlens_k is not None: - # Variable length sequence mode - layout = "thd" - batch = len(cu_seqlens_q) - 1 - _, nheads_q, _ = q.shape - else: - # Regular batch mode - layout = "bshd" - batch, _, nheads_q, _ = q.shape - max_seqlen_q = q.shape[1] if max_seqlen_q is None else max_seqlen_q - max_seqlen_k = k.shape[1] if max_seqlen_k is None else max_seqlen_k - - # V3 backward doesn't have dropout or alibi slopes - dropout_p = 0.0 - philox_seed, philox_offset = None, None - alibi_slopes = None - - # For fp8, we would need descale factors, but v3 interface doesn't expose them - # So we'll pass None for now - descale_q = None - descale_k = None - descale_v = None - descale_o = None - descale_do = None - descale_dq = None - descale_dk = None - descale_dv = None - - # Call implementation - if USE_REF: - if DEBUG: - print("Using reference implementation") - delta_ref = attention_backward_pytorch_ref_impl( - dout, q, k, v, out, softmax_lse, - dq, dk, dv, - softmax_scale, - alibi_slopes, - causal, - window_size_left, - window_size_right, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - USE_EXP2, - ) - delta = delta_ref - else: - if DEBUG: - print("Using Triton implementation") - - if BWD_MODE == "split": - delta_triton = attention_prefill_backward_triton_split_impl( - dout, q, k, v, out, softmax_lse, - dq, dk, dv, - softmax_scale, - alibi_slopes, - causal, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - USE_EXP2, - descale_q, descale_k, descale_v, descale_o, - descale_do, descale_dq, descale_dk, descale_dv, - seqused_q, seqused_k, - ) - delta = delta_triton - elif BWD_MODE == "fused_atomics": - delta_triton = attention_prefill_backward_triton_fused_atomics_impl( - dout, q, k, v, out, softmax_lse, - dq, dk, dv, - softmax_scale, - alibi_slopes, - causal, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - descale_q, descale_k, descale_v, descale_o, - True, - ) - delta = delta_triton - elif BWD_MODE == "fused_no_atomics": - delta_triton = attention_prefill_backward_triton_split_fused_no_atomics_impl( - dout, q, k, v, out, softmax_lse, - dq, dk, dv, - softmax_scale, - alibi_slopes, - causal, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - USE_EXP2, - descale_q, descale_k, descale_v, descale_o, - descale_do, descale_dq, descale_dk, descale_dv, - seqused_q, seqused_k, - ) - delta = delta_triton - else: - raise ValueError(f"Unknown bwd mode {BWD_MODE}") - - if DEBUG: - print("interface_fa_v3.py::bwd outputs") - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) - print("dv:", dv, dv.shape) - print("delta:", delta, delta.shape if delta is not None else None) - - # V3 expects (dq, dk, dv, softmax_d, *rest) - # delta is the softmax_d in this case - return dq, dk, dv, delta - - -def fwd_combine( - out_partial: torch.Tensor, - lse_partial: torch.Tensor, - out: Optional[torch.Tensor] = None, - out_dtype: Optional[torch.dtype] = None, -) -> torch.Tensor: - """ - Combine partial outputs from split attention computation. - - This is used when num_splits > 1 to combine the partial results. - - Args: - out_partial: Partial output tensor from split computation - lse_partial: Partial log-sum-exp tensor - out: Optional output tensor to write to - out_dtype: Optional dtype for output - - Returns: - Combined output tensor - """ - raise NotImplementedError("fwd_combine is not yet implemented in the AMD Triton backend") - - -def get_scheduler_metadata( - batch_size: int, - max_seqlen_q: int, - max_seqlen_k: int, - num_heads_q: int, - num_heads_kv: int, - headdim: int, - headdim_v: int, - qkv_dtype: torch.dtype, - cache_seqlens: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k: Optional[torch.Tensor] = None, - cu_seqlens_k_new: Optional[torch.Tensor] = None, - seqused_q: Optional[torch.Tensor] = None, - cache_leftpad: Optional[torch.Tensor] = None, - page_size: Optional[int] = None, - max_seqlen_k_new: int = 0, - causal: bool = False, - window_size_left: int = -1, - window_size_right: int = -1, - attention_chunk: int = 0, - has_softcap: bool = False, - num_splits: int = 0, - pack_gqa: Optional[bool] = None, - sm_margin: int = 0, -): - """ - Get scheduler metadata for optimized kernel selection. - - This function is used to precompute metadata for kernel scheduling in FA3. - The AMD Triton backend currently doesn't use scheduler metadata, so this - raises an error. - - Args: - Various attention parameters used for scheduling decisions - - Returns: - None - scheduler metadata is not used in AMD Triton backend - """ - raise NotImplementedError("get_scheduler_metadata is not supported in the AMD Triton backend yet.") \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/interface_v2.py b/flash_attn/flash_attn_triton_amd/interface_v2.py new file mode 100644 index 00000000000..134c4a76c12 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/interface_v2.py @@ -0,0 +1,674 @@ +import torch +import os +from typing import Optional, Union +from .fwd_prefill import attention_forward_prefill_triton_impl +from .fwd_decode import attention_forward_decode_triton_impl +from .bwd import attention_backward_triton_impl +from .utils import DEBUG, USE_EXP2, BWD_MODE, PHILOX_SEED, PHILOX_OFFSET + + +def fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, +): + + # Reject FP8 tensors (FA2 AMD path does not support FP8) + if str(q.dtype).startswith("torch.float8"): + raise NotImplementedError( + "FP8 tensors are not supported in the AMD Triton FA2 interface. Use the FA3 path instead." + ) + + # Unsupported features assertions (keep behavior explicit like v3 shim) + if softcap != 0.0: + raise NotImplementedError( + "softcap is not supported in the AMD Triton FA2 interface (expected 0.0)." + ) + + if DEBUG: + print() + print("flash_attn_triton_amd.py::fwd inputs") + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("out:", out, out.shape if out is not None else None) + print("alibi_slopes:", alibi_slopes) + print("dropout_p:", dropout_p) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("softcap:", softcap) + print("return_softmax:", return_softmax) + out = torch.zeros_like(q) if out is None else out.zero_() + + # Layout / shapes + layout = "bshd" + max_seqlen_q = q.shape[1] + max_seqlen_k = k.shape[1] + batch, _, nheads_q, _ = q.shape + + # Normalize / validate alibi + if alibi_slopes is not None: + if alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + assert alibi_slopes.is_cuda and alibi_slopes.dim() == 2 + assert alibi_slopes.shape == (batch, nheads_q) + + # Dropout + RNG seed + philox_seed, philox_offset = PHILOX_SEED, PHILOX_OFFSET + rng_state = torch.as_tensor([philox_seed, philox_offset]) + + # argument checks + assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4 + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert q.dtype == k.dtype == v.dtype + assert out.shape[:-1] == q.shape[:-1] and out.shape[-1] == v.shape[-1] + nheads_k = k.shape[2] + assert (nheads_q % nheads_k) == 0 + + # call implementation + if DEBUG: + print("Using Triton implementation") + softmax_lse, sd_mask = attention_forward_prefill_triton_impl( + q, + k, + v, + out, + softmax_scale, + alibi_slopes, + causal, + window_size_left, + window_size_right, + None, + layout, + None, + None, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + return_softmax, + USE_EXP2, + None, + None, + None, + ) + + if DEBUG: + print("flash_attn_triton_amd.py::fwd outputs") + print("o:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None) + print("rng_state:", rng_state) + + # --- Assertions (shape + dtype contracts) --- + # out: (B, Sq, Hq, D) + assert out.shape == q.shape, f"[fwd] out shape {out.shape} != q shape {q.shape}" + # softmax_lse: (B, Hq, Sq) + expected_lse_shape = (q.shape[0], q.shape[2], q.shape[1]) + assert ( + softmax_lse.shape == expected_lse_shape + ), f"[fwd] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" + assert ( + softmax_lse.dtype == torch.float32 + ), f"[fwd] softmax_lse dtype {softmax_lse.dtype} != torch.float32" + if return_softmax: + # sd_mask: (B, Hq, Sq, Sk) + assert sd_mask is not None, "[fwd] return_softmax=True but sd_mask is None" + assert sd_mask.dim() == 4, f"[fwd] sd_mask dim {sd_mask.dim()} != 4" + assert ( + sd_mask.shape[0] == q.shape[0] + and sd_mask.shape[1] == q.shape[2] + and sd_mask.shape[2] == q.shape[1] + ), f"[fwd] sd_mask leading dims {sd_mask.shape[:3]} mismatch (B,Hq,Sq) {(q.shape[0], q.shape[2], q.shape[1])}" + else: + assert sd_mask is None, "[fwd] return_softmax=False but sd_mask is not None" + + return out, softmax_lse, sd_mask, rng_state + + +def bwd( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + deterministic: bool, + gen_: Optional[torch.Tensor] = None, + rng_state: Optional[torch.Tensor] = None, +): + if softcap != 0.0: + raise NotImplementedError( + "softcap is not supported in the AMD Triton FA2 interface (expected 0.0)." + ) + + if DEBUG: + print() + print("flash_attn_triton_amd.py::bwd inputs") + print("dout:", dout, dout.shape) + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("dq:", dq, dq.shape if dq is not None else None) + print("dk:", dk, dk.shape if dk is not None else None) + print("dv:", dv, dv.shape if dv is not None else None) + print("alibi_slopes:", alibi_slopes) + print("dropout_p:", dropout_p) + print("out:", out) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("deterministic:", deterministic) + print("gen_:", gen_) + print("rng_state:", rng_state) + + dq = torch.zeros_like(q) if dq is None else dq.zero_() + dk = torch.zeros_like(k) if dk is None else dk.zero_() + dv = torch.zeros_like(v) if dv is None else dv.zero_() + + # get shape + batch, _, nheads_q, _ = q.shape + + # Upstream change: base seeding logic on provided rng_state instead of dropout probability. + if rng_state is not None: + philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() + else: + philox_seed, philox_offset = None, None + + if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") + + # call implementation + if DEBUG: + print("Using Triton implementation") + delta = attention_backward_triton_impl( + do=dout, + q=q, + k=k, + v=v, + o=out, + softmax_lse=softmax_lse, + dq=dq, + dk=dk, + dv=dv, + sm_scale=softmax_scale, + alibi_slopes=alibi_slopes, + causal=causal, + layout="bshd", + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=q.shape[1], + max_seqlen_k=k.shape[1], + seqused_q=None, + seqused_k=None, + dropout_p=dropout_p, + philox_seed=philox_seed, + philox_offset=philox_offset, + use_exp2=USE_EXP2, + mode=BWD_MODE, + ) + + if DEBUG: + print("flash_attn_triton_amd.py::bwd outputs") + print("dv:", dv, dv.shape) + print("dk:", dk, dk.shape) + print("dq:", dq, dq.shape) + # --- Assertions --- + assert dq.shape == q.shape, f"[bwd] dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"[bwd] dk shape {dk.shape} != k shape {k.shape}" + assert dv.shape == v.shape, f"[bwd] dv shape {dv.shape} != v shape {v.shape}" + # delta (softmax_d) : (B, Hq, Sq) + expected_delta_shape = (q.shape[0], q.shape[2], q.shape[1]) + assert ( + delta.shape == expected_delta_shape + ), f"[bwd] delta shape {delta.shape} != {expected_delta_shape}" + return dq, dk, dv, delta + + +def varlen_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], + leftpad_k: Optional[torch.Tensor], + block_table_: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + softmax_scale: float, + zero_tensors: bool, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, +): + + if str(q.dtype).startswith("torch.float8"): + raise NotImplementedError( + "FP8 tensors are not supported in the AMD Triton FA2 interface (varlen_fwd). Use the FA3 path instead." + ) + + if softcap != 0.0: + raise NotImplementedError( + "softcap is not supported in varlen_fwd (expected 0.0)." + ) + if leftpad_k is not None: + raise NotImplementedError( + "leftpad_k is not supported in AMD Triton FA2 varlen_fwd." + ) + if block_table_ is not None: + raise NotImplementedError( + "block_table / paged attention is not supported in AMD Triton FA2 varlen_fwd." + ) + if seqused_k is not None: + raise NotImplementedError( + "seqused_k is not supported in AMD Triton FA2 varlen_fwd." + ) + + if DEBUG: + print() + print("flash_attn_triton_amd.py::varlen_fwd") + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) + print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) + print("alibi_slopes:", alibi_slopes) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("dropout_p:", dropout_p) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("gen_:", gen_) + out = torch.zeros_like(q) if out is None else out.zero_() + + # Layout and basic info for varlen + layout = "thd" + batch = len(cu_seqlens_q) - 1 + _, nheads_q, _ = q.shape + + if alibi_slopes is not None: + if alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + assert alibi_slopes.is_cuda and alibi_slopes.dim() == 2 + assert alibi_slopes.shape == (batch, nheads_q) + + philox_seed, philox_offset = PHILOX_SEED, PHILOX_OFFSET + rng_state = torch.as_tensor([philox_seed, philox_offset]) + + # Inline checks (subset appropriate for varlen) + assert q.dim() == 3 and k.dim() == 3 and v.dim() == 3 + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert q.dtype == k.dtype == v.dtype + assert out.shape == q.shape + nheads_k = k.shape[1] + assert (nheads_q % nheads_k) == 0 + + # call implementation + if DEBUG: + print("Using Triton implementation") + softmax_lse, sd_mask = attention_forward_prefill_triton_impl( + q, + k, + v, + out, + softmax_scale, + alibi_slopes, + causal, + window_size_left, + window_size_right, + None, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + return_softmax, + USE_EXP2, + None, + None, + None, + ) + + if DEBUG: + print("varlen_fwd outputs") + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None) + # --- Assertions --- + # out: (Total_Q, Hq, D) + assert ( + out.shape == q.shape + ), f"[varlen_fwd] out shape {out.shape} != q shape {q.shape}" + # softmax_lse: (Hq, Total_Q) + expected_lse_shape = (q.shape[1], q.shape[0]) + assert ( + softmax_lse.shape == expected_lse_shape + ), f"[varlen_fwd] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" + assert ( + softmax_lse.dtype == torch.float32 + ), f"[varlen_fwd] softmax_lse dtype {softmax_lse.dtype} != torch.float32" + if return_softmax: + # sd_mask expected: (B, Hq, max_seqlen_q, max_seqlen_k) + assert ( + sd_mask is not None + ), "[varlen_fwd] return_softmax=True but sd_mask is None" + assert sd_mask.dim() == 4, f"[varlen_fwd] sd_mask dim {sd_mask.dim()} != 4" + assert sd_mask.shape[0] == ( + len(cu_seqlens_q) - 1 + ), f"[varlen_fwd] sd_mask batch {sd_mask.shape[0]} != {len(cu_seqlens_q)-1}" + assert ( + sd_mask.shape[1] == q.shape[1] + ), f"[varlen_fwd] sd_mask nheads {sd_mask.shape[1]} != {q.shape[1]}" + else: + assert ( + sd_mask is None + ), "[varlen_fwd] return_softmax=False but sd_mask is not None" + return out, softmax_lse, sd_mask, rng_state + + +def varlen_bwd( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + alibi_slopes: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + softmax_scale: float, + zero_tensors: bool, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + deterministic: bool, + gen_: Optional[torch.Tensor] = None, + rng_state: Optional[torch.Tensor] = None, +): + if str(q.dtype).startswith("torch.float8"): + raise NotImplementedError( + "FP8 tensors are not supported in the AMD Triton FA2 interface (varlen_bwd). Use the FA3 path instead." + ) + if softcap != 0.0: + raise NotImplementedError( + "softcap is not supported in varlen_bwd (expected 0.0)." + ) + + if DEBUG: + print() + print("varlen_bwd") + print("dout:", dout, dout.shape) + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("out:", out) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("dq:", dq, dq.shape if dq is not None else None) + print("dk:", dk, dk.shape if dk is not None else None) + print("dv:", dv, dv.shape if dv is not None else None) + print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) + print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) + print("alibi_slopes:", alibi_slopes) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("dropout_p:", dropout_p) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("deterministic:", deterministic) + print("gen_:", gen_) + print("rng_state:", rng_state) + + dq = torch.zeros_like(q) if dq is None else dq.zero_() + dk = torch.zeros_like(k) if dk is None else dk.zero_() + dv = torch.zeros_like(v) if dv is None else dv.zero_() + + # get shape + batch = len(cu_seqlens_q) - 1 + _, nheads_q, _ = q.shape + + # Upstream change: base seeding logic on provided rng_state instead of dropout probability. + if rng_state is not None: + philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() + else: + philox_seed, philox_offset = None, None + + if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).") + + # call implementation + if DEBUG: + print("Using Triton implementation") + delta = attention_backward_triton_impl( + do=dout, + q=q, + k=k, + v=v, + o=out, + softmax_lse=softmax_lse, + dq=dq, + dk=dk, + dv=dv, + sm_scale=softmax_scale, + alibi_slopes=alibi_slopes, + causal=causal, + layout="thd", + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=None, + seqused_k=None, + dropout_p=dropout_p, + philox_seed=philox_seed, + philox_offset=philox_offset, + use_exp2=USE_EXP2, + mode=BWD_MODE, + ) + + if DEBUG: + print("varlen_bwd outputs") + print("delta:", delta, delta.shape) + print("dv:", dv, dv.shape) + print("dk:", dk, dk.shape) + print("dq:", dq, dq.shape) + # --- Assertions --- + assert dq.shape == q.shape, f"[varlen_bwd] dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"[varlen_bwd] dk shape {dk.shape} != k shape {k.shape}" + assert dv.shape == v.shape, f"[varlen_bwd] dv shape {dv.shape} != v shape {v.shape}" + expected_delta_shape = (q.shape[1], q.shape[0]) # (Hq, Total_Q) + assert ( + delta.shape == expected_delta_shape + ), f"[varlen_bwd] delta shape {delta.shape} != {expected_delta_shape}" + return dq, dk, dv, delta + + +def fwd_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + cache_seqlens: Optional[Union[(int, torch.Tensor)]], + rotary_cos: Optional[torch.Tensor], + rotary_sin: Optional[torch.Tensor], + cache_batch_idx: Optional[torch.Tensor], + cache_leftpad: Optional[torch.Tensor], + block_table: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + out: Optional[torch.Tensor], + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + rotary_interleaved: bool, + num_splits: int, +): + + if softcap != 0.0: + raise NotImplementedError( + "softcap is not supported in fwd_kvcache (expected 0.0)." + ) + if num_splits not in (0, 1): + raise NotImplementedError( + "num_splits > 1 not supported in AMD Triton FA2 fwd_kvcache." + ) + + if DEBUG: + print() + print("flash_attn_triton_amd.py::fwd_kvcache inputs") + print("q:", q, q.shape) + print("k_cache:", k_cache, k_cache.shape) + print("v_cache:", v_cache, v_cache.shape) + print("k:", k, k.shape if k is not None else None) + print("v:", v, v.shape if v is not None else None) + print("cache_seqlens:", cache_seqlens) + print("rotary_cos:", rotary_cos) + print("rotary_sin:", rotary_sin) + print("cache_batch_idx:", cache_batch_idx) + print("cache_leftpad:", cache_leftpad) + print("block_table:", block_table) + print("alibi_slopes:", alibi_slopes) + print("out:", out) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("softcap:", softcap) + print("rotary_interleaved:", rotary_interleaved) + print("num_splits:", num_splits) + + # output + out = torch.zeros_like(q) if out is None else out.zero_() + + # Basic layout info for decode path + layout = "bshd" + max_seqlen_q = q.shape[1] + max_seqlen_k = k_cache.shape[1] + cache_seqlens_tensor = ( + torch.tensor(cache_seqlens, device=q.device) + if isinstance(cache_seqlens, int) + else cache_seqlens + ) + window_left = ( + int(window_size_left.item()) + if isinstance(window_size_left, torch.Tensor) + else window_size_left + ) + window_right = ( + int(window_size_right.item()) + if isinstance(window_size_right, torch.Tensor) + else window_size_right + ) + + k_new = k + v_new = v + + # get shape + batch, _, nheads_q, _ = q.shape + + if alibi_slopes is not None: + if alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + assert alibi_slopes.is_cuda and alibi_slopes.dim() == 2 + assert alibi_slopes.shape == (batch, nheads_q) + + # launch kernel + if DEBUG: + print("Using Triton implementation") + softmax_lse = attention_forward_decode_triton_impl( + q, + k_cache, + v_cache, + k_new, + v_new, + out, + softmax_scale, + causal, + window_left, + window_right, + alibi_slopes, + layout, + cache_seqlens_tensor, + cache_batch_idx, + block_table, + None, + None, + None, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + rotary_interleaved=rotary_interleaved, + ) + + if DEBUG: + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + # --- Assertions --- + assert ( + out.shape == q.shape + ), f"[fwd_kvcache] out shape {out.shape} != q shape {q.shape}" + expected_lse_shape = (q.shape[0], q.shape[2], q.shape[1]) + assert ( + softmax_lse.shape == expected_lse_shape + ), f"[fwd_kvcache] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" + assert ( + softmax_lse.dtype == torch.float32 + ), f"[fwd_kvcache] softmax_lse dtype {softmax_lse.dtype} != torch.float32" + return out, softmax_lse diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py new file mode 100755 index 00000000000..436077a8a7c --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -0,0 +1,608 @@ +import torch +import os +from typing import Optional, Union, Tuple +from .fwd_prefill import attention_forward_prefill_triton_impl +from .fwd_decode import attention_forward_decode_triton_impl +from .bwd import attention_backward_triton_impl +from .utils import DEBUG, USE_EXP2, BWD_MODE, PHILOX_SEED, PHILOX_OFFSET, is_fp8 + + +def fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_new: Optional[torch.Tensor], + v_new: Optional[torch.Tensor], + qv: Optional[torch.Tensor], + out: Optional[torch.Tensor], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + cu_seqlens_k_new: Optional[torch.Tensor], + seqused_q: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + page_table: Optional[torch.Tensor], + kv_batch_idx: Optional[torch.Tensor], + leftpad_k: Optional[torch.Tensor], + rotary_cos: Optional[torch.Tensor], + rotary_sin: Optional[torch.Tensor], + seqlens_rotary: Optional[torch.Tensor], + q_descale: Optional[torch.Tensor], + k_descale: Optional[torch.Tensor], + v_descale: Optional[torch.Tensor], + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + attention_chunk: int, + softcap: float, + rotary_interleaved: bool, + scheduler_metadata=None, + num_splits: int = 1, + pack_gqa=None, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Flash Attention v3 forward pass compatible interface for AMD Triton implementation. + + This function maps v3 parameters to the existing AMD Triton implementation. + """ + + if DEBUG: + print() + print("interface_fa_v3.py::fwd inputs") + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("k_new:", k_new, k_new.shape if k_new is not None else None) + print("v_new:", v_new, v_new.shape if v_new is not None else None) + print("qv:", qv, qv.shape if qv is not None else None) + print("out:", out, out.shape if out is not None else None) + print( + "cu_seqlens_q:", + cu_seqlens_q, + cu_seqlens_q.shape if cu_seqlens_q is not None else None, + ) + print( + "cu_seqlens_k:", + cu_seqlens_k, + cu_seqlens_k.shape if cu_seqlens_k is not None else None, + ) + print( + "cu_seqlens_k_new:", + cu_seqlens_k_new, + cu_seqlens_k_new.shape if cu_seqlens_k_new is not None else None, + ) + print( + "seqused_q:", seqused_q, seqused_q.shape if seqused_q is not None else None + ) + print( + "seqused_k:", seqused_k, seqused_k.shape if seqused_k is not None else None + ) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print( + "page_table:", + page_table, + page_table.shape if page_table is not None else None, + ) + print( + "kv_batch_idx:", + kv_batch_idx, + kv_batch_idx.shape if kv_batch_idx is not None else None, + ) + print( + "leftpad_k:", leftpad_k, leftpad_k.shape if leftpad_k is not None else None + ) + print( + "rotary_cos:", + rotary_cos, + rotary_cos.shape if rotary_cos is not None else None, + ) + print( + "rotary_sin:", + rotary_sin, + rotary_sin.shape if rotary_sin is not None else None, + ) + print( + "seqlens_rotary:", + seqlens_rotary, + seqlens_rotary.shape if seqlens_rotary is not None else None, + ) + print( + "q_descale:", q_descale, q_descale.shape if q_descale is not None else None + ) + print( + "k_descale:", k_descale, k_descale.shape if k_descale is not None else None + ) + print( + "v_descale:", v_descale, v_descale.shape if v_descale is not None else None + ) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("attention_chunk:", attention_chunk) + print("softcap:", softcap) + print("rotary_interleaved:", rotary_interleaved) + print("scheduler_metadata:", scheduler_metadata) + print("num_splits:", num_splits) + print("pack_gqa:", pack_gqa) + print("sm_margin:", sm_margin) + + # Handle qv packed input + if qv is not None: + raise NotImplementedError( + "QV packed input is not yet supported in the AMD Triton backend" + ) + + # Handle softcap + if softcap != 0.0: + raise NotImplementedError( + f"Softcap is not yet supported in the AMD Triton backend (got softcap={softcap}, expected 0.0)" + ) + + # Handle attention_chunk + if attention_chunk != 0 and attention_chunk != 1: + raise NotImplementedError( + f"attention_chunk is not yet supported in the AMD Triton backend (got attention_chunk={attention_chunk})" + ) + + # Handle scheduler metadata + if scheduler_metadata is not None: + raise NotImplementedError( + "Scheduler metadata is not yet supported in the AMD Triton backend" + ) + + # Handle pack_gqa + if pack_gqa is not None and pack_gqa is not False: + raise NotImplementedError( + f"pack_gqa is not yet supported in the AMD Triton backend (got pack_gqa={pack_gqa})" + ) + + # Handle num_splits + if num_splits != 1: + raise NotImplementedError( + f"Split attention (num_splits > 1) is not yet supported in the AMD Triton backend (got num_splits={num_splits})" + ) + + # Handle sm_margin + if sm_margin != 0: + raise NotImplementedError( + f"sm_margin is not yet supported in the AMD Triton backend (got sm_margin={sm_margin}, expected 0)" + ) + + # Handle leftpad_k + if leftpad_k is not None: + raise NotImplementedError( + "Left padding (leftpad_k) is not yet supported in the AMD Triton backend" + ) + + # Handle cu_seqlens_k_new + if cu_seqlens_k_new is not None: + raise NotImplementedError( + "cu_seqlens_k_new is not yet supported in the AMD Triton backend" + ) + + # if seqlens_rotary is not None: + # raise NotImplementedError("seqlens_rotary is not yet supported in the AMD Triton backend") + + # establish layout / varlen & max seq lens + if cu_seqlens_q is not None: + if len(q.shape) != 3: + raise ValueError( + f"cu_seqlens_q provided but q has shape {q.shape}, expected 3D tensor for varlen" + ) + layout = "thd" + cu_seqlens_q_local = cu_seqlens_q + max_seqlens_q_local = max_seqlen_q + if cu_seqlens_k is not None: + cu_seqlens_k_local = cu_seqlens_k + max_seqlens_k_local = max_seqlen_k + else: + cu_seqlens_k_local = None + max_seqlens_k_local = k.shape[1] if len(k.shape) == 4 else max_seqlen_k + else: + layout = "bshd" + cu_seqlens_q_local = None + cu_seqlens_k_local = None + max_seqlens_q_local = q.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlens_k_local = k.shape[1] if max_seqlen_k is None else max_seqlen_k + + # Now determine if we should use decode or prefill kernel + # Decode kernel should be used for KV cache scenarios where: + # 1. k_new/v_new are provided - incremental KV cache update (primary KV cache indicator) + # 2. kv_batch_idx is provided - KV cache batch indexing (primary KV cache indicator) + # 3. seqused_k without seqused_q - indicates KV cache fill levels (not varlen masking) + # Note: In varlen, both seqused_q and seqused_k are used for sequence masking + # In KV cache, only seqused_k is used to track cache fill levels + # Detect KV cache scenarios: + # - Clear KV cache indicators (k_new, v_new, kv_batch_idx) + # - OR seqused_k without seqused_q (KV cache fill tracking, not varlen masking) + use_decode = ( + k_new is not None # Have new KV to append (KV cache indicator) + or v_new is not None # Have new KV to append (KV cache indicator) + or kv_batch_idx is not None # Have KV cache batch indexing (KV cache indicator) + or ( + seqused_k is not None and seqused_q is None + ) # KV cache fill levels (not varlen) + ) + + # Check for unsupported features with decode kernel + if use_decode: + if layout == "thd": + raise NotImplementedError( + "Varlen is not yet supported with the decode kernel in the AMD Triton backend" + ) + if kv_batch_idx is not None: + raise NotImplementedError( + "kv_batch_idx is not yet supported with the decode kernel in the AMD Triton backend" + ) + + if out is None: + out_dtype = torch.float32 if is_fp8(q) else q.dtype + if layout == "bshd": + out = torch.zeros( + q.shape[0], + q.shape[1], + q.shape[2], + v.shape[-1], + dtype=out_dtype, + device=q.device, + ) + elif layout == "thd": + out = torch.zeros( + q.shape[0], q.shape[1], v.shape[-1], dtype=out_dtype, device=q.device + ) + else: + raise ValueError( + f"Unsupported layout: {layout}. Only 'bshd' and 'thd' layouts are supported." + ) + else: + out = out.zero_() + + if is_fp8(q): + if (q_descale is None) or (k_descale is None) or (v_descale is None): + import warnings + + warnings.warn( + "FP8 tensors detected but descale factors not provided. Using default scale of 1.0", + UserWarning, + ) + else: + # Enforce exact expected shapes; no reshaping or normalization. + if layout == "bshd": + expected_batch = q.shape[0] + expected_q_heads = q.shape[2] + expected_kv_heads = k.shape[2] + else: # thd layout + expected_batch = ( + (len(cu_seqlens_q_local) - 1) + if cu_seqlens_q_local is not None + else 1 + ) + expected_q_heads = q.shape[1] + expected_kv_heads = k.shape[1] + + assert ( + q_descale.dim() == 2 + and q_descale.shape[0] == expected_batch + and q_descale.shape[1] == expected_kv_heads + ), f"q_descale expected shape ({expected_batch}, {expected_kv_heads}) got {tuple(q_descale.shape)}" + assert ( + k_descale.dim() == 2 + and k_descale.shape[0] == expected_batch + and k_descale.shape[1] == expected_kv_heads + ), f"k_descale expected shape ({expected_batch}, {expected_kv_heads}) got {tuple(k_descale.shape)}" + assert ( + v_descale.dim() == 2 + and v_descale.shape[0] == expected_batch + and v_descale.shape[1] == expected_kv_heads + ), f"v_descale expected shape ({expected_batch}, {expected_kv_heads}) got {tuple(v_descale.shape)}" + + # Handle causal mask + causal_flag = bool(causal) + + # Handle alibi slopes + alibi_slopes = None + + # Handle dropout + dropout_p = 0.0 + return_softmax = False + philox_seed = PHILOX_SEED + philox_offset = PHILOX_OFFSET + + # Call implementation + if DEBUG: + print("Using Triton implementation") + + if use_decode: + if DEBUG: + print( + f"Using Decode Triton implementation (cache_seqlens={seqused_k is not None}, k_new={k_new is not None}, v_new={v_new is not None}, kv_batch_idx={kv_batch_idx is not None})" + ) + + softmax_lse = attention_forward_decode_triton_impl( + q, + k, + v, + k_new, + v_new, + out, + softmax_scale, + causal_flag, + window_size_left, + window_size_right, + alibi_slopes, + layout, + seqused_k, + kv_batch_idx, + page_table, + q_descale, + k_descale, + v_descale, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + rotary_interleaved=rotary_interleaved, + seqlens_rotary=seqlens_rotary, + ) + else: + if DEBUG: + print("Using Prefill Triton implementation") + softmax_lse, _ = attention_forward_prefill_triton_impl( + q, + k, + v, + out, + softmax_scale, + alibi_slopes, + causal_flag, + window_size_left, + window_size_right, + None, + layout, + cu_seqlens_q_local, + cu_seqlens_k_local, + max_seqlens_q_local, + max_seqlens_k_local, + dropout_p, + philox_seed, + philox_offset, + return_softmax, + USE_EXP2, + q_descale, + k_descale, + v_descale, + seqused_q, + seqused_k, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + rotary_interleaved=rotary_interleaved, + seqlens_rotary=seqlens_rotary, + ) + + if DEBUG: + print("interface_fa_v3.py::fwd outputs") + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + + # Return format compatible with v3 + # V3 returns (out, softmax_lse, *rest) where rest can be empty or contain additional outputs + return out, softmax_lse + + +def bwd( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + seqused_q: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + deterministic: bool, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Flash Attention v3 backward pass compatible interface for AMD Triton implementation. + + This function maps v3 parameters to the existing AMD Triton implementation. + """ + + if DEBUG: + print() + print("interface_fa_v3.py::bwd inputs") + print("dout:", dout, dout.shape) + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("dq:", dq, dq.shape if dq is not None else None) + print("dk:", dk, dk.shape if dk is not None else None) + print("dv:", dv, dv.shape if dv is not None else None) + print( + "cu_seqlens_q:", + cu_seqlens_q, + cu_seqlens_q.shape if cu_seqlens_q is not None else None, + ) + print( + "cu_seqlens_k:", + cu_seqlens_k, + cu_seqlens_k.shape if cu_seqlens_k is not None else None, + ) + print( + "seqused_q:", seqused_q, seqused_q.shape if seqused_q is not None else None + ) + print( + "seqused_k:", seqused_k, seqused_k.shape if seqused_k is not None else None + ) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("softcap:", softcap) + print("deterministic:", deterministic) + print("sm_margin:", sm_margin) + + # Check for unsupported features in backward pass + + # Handle softcap + if softcap != 0.0: + raise NotImplementedError( + f"Softcap is not yet supported in the AMD Triton backend backward pass (got softcap={softcap}, expected 0.0)" + ) + + # Handle sm_margin + if sm_margin != 0: + raise NotImplementedError( + f"sm_margin is not yet supported in the AMD Triton backend backward pass (got sm_margin={sm_margin}, expected 0)" + ) + + # Initialize gradient tensors if not provided + dq = torch.zeros_like(q) if dq is None else dq.zero_() + dk = torch.zeros_like(k) if dk is None else dk.zero_() + dv = torch.zeros_like(v) if dv is None else dv.zero_() + + # Determine layout based on cu_seqlens + if cu_seqlens_q is not None and cu_seqlens_k is not None: + # Variable length sequence mode + layout = "thd" + batch = len(cu_seqlens_q) - 1 + _, nheads_q, _ = q.shape + else: + # Regular batch mode + layout = "bshd" + batch, _, nheads_q, _ = q.shape + max_seqlen_q = q.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_k = k.shape[1] if max_seqlen_k is None else max_seqlen_k + + # V3 backward doesn't have dropout or alibi slopes + dropout_p = 0.0 + philox_seed, philox_offset = None, None + alibi_slopes = None + + # Call implementation + if DEBUG: + print("Using Triton implementation (unified backward dispatcher)") + delta = attention_backward_triton_impl( + do=dout, + q=q, + k=k, + v=v, + o=out, + softmax_lse=softmax_lse, + dq=dq, + dk=dk, + dv=dv, + sm_scale=softmax_scale, + alibi_slopes=alibi_slopes, + causal=causal, + layout=layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + dropout_p=dropout_p, + philox_seed=philox_seed, + philox_offset=philox_offset, + use_exp2=USE_EXP2, + mode=BWD_MODE, + ) + + if DEBUG: + print("interface_fa_v3.py::bwd outputs") + print("dq:", dq, dq.shape) + print("dk:", dk, dk.shape) + print("dv:", dv, dv.shape) + print("delta:", delta, delta.shape if delta is not None else None) + + # V3 expects (dq, dk, dv, softmax_d, *rest) + # delta is the softmax_d in this case + return dq, dk, dv, delta + + +def fwd_combine( + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """ + Combine partial outputs from split attention computation. + + This is used when num_splits > 1 to combine the partial results. + + Args: + out_partial: Partial output tensor from split computation + lse_partial: Partial log-sum-exp tensor + out: Optional output tensor to write to + out_dtype: Optional dtype for output + + Returns: + Combined output tensor + """ + raise NotImplementedError( + "fwd_combine is not yet implemented in the AMD Triton backend" + ) + + +def get_scheduler_metadata( + batch_size: int, + max_seqlen_q: int, + max_seqlen_k: int, + num_heads_q: int, + num_heads_kv: int, + headdim: int, + headdim_v: int, + qkv_dtype: torch.dtype, + cache_seqlens: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new: int = 0, + causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + attention_chunk: int = 0, + has_softcap: bool = False, + num_splits: int = 0, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, +): + """ + Get scheduler metadata for optimized kernel selection. + + This function is used to precompute metadata for kernel scheduling in FA3. + The AMD Triton backend currently doesn't use scheduler metadata, so this + raises an error. + + Args: + Various attention parameters used for scheduling decisions + + Returns: + None - scheduler metadata is not used in AMD Triton backend + """ + raise NotImplementedError( + "get_scheduler_metadata is not supported in the AMD Triton backend yet." + ) diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 44502785a35..71ed1c1c2de 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -7,147 +7,57 @@ import triton import triton.language as tl import numpy as np -from typing import Literal, Optional +from typing import Literal, Optional, Union, Tuple # ------------------------------- # Gloabl Variables # ------------------------------- -AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes') +AUTOTUNE = os.environ.get("FLASH_ATTENTION_TRITON_AMD_AUTOTUNE", "0").lower() in ( + "1", + "true", + "yes", +) if AUTOTUNE: os.environ["TRITON_PRINT_AUTOTUNING"] = "1" -DEBUG = os.environ.get('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in ('1', 'true', 'yes') -USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes') -PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes') -USE_SINGLE_BWD_KERNEL = os.environ.get('USE_SINGLE_BWD_KERNEL', '0').lower() in ('1', 'true', 'yes') +DEBUG = os.environ.get("FLASH_ATTENTION_TRITON_AMD_DEBUG", "0").lower() in ( + "1", + "true", + "yes", +) +PERF = os.environ.get("FLASH_ATTENTION_TRITON_AMD_PERF", "0").lower() in ( + "1", + "true", + "yes", +) +USE_SINGLE_BWD_KERNEL = os.environ.get("USE_SINGLE_BWD_KERNEL", "0").lower() in ( + "1", + "true", + "yes", +) USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" -USE_TRITON_INTERPRET = os.environ.get('TRITON_INTERPRET', '0').lower() in ('1', 'true', 'yes') -DEBUG_TRITON = os.environ.get('DEBUG_TRITON', '0').lower() in ('1', 'true', 'yes') and USE_TRITON_INTERPRET -DEBUG_TRITON_DETAIL = os.environ.get('DEBUG_TRITON_DETAIL', '0').lower() in ('1', 'true', 'yes') and USE_TRITON_INTERPRET -if USE_TRITON_ROCM: # TODO remove this +USE_TRITON_INTERPRET = os.environ.get("TRITON_INTERPRET", "0").lower() in ( + "1", + "true", + "yes", +) +DEBUG_TRITON = ( + os.environ.get("DEBUG_TRITON", "0").lower() in ("1", "true", "yes") + and USE_TRITON_INTERPRET +) +DEBUG_TRITON_DETAIL = ( + os.environ.get("DEBUG_TRITON_DETAIL", "0").lower() in ("1", "true", "yes") + and USE_TRITON_INTERPRET +) +if USE_TRITON_ROCM: # TODO remove this random.seed(42) +BWD_MODE = os.environ.get("BWD_MODE", "fused_no_atomics").lower() DROPOUT_USE_PYTORCH = False DROPOUT_DUMP = False +USE_EXP2 = True +PHILOX_SEED = 0x1BF58 +PHILOX_OFFSET = 0x1D4B49 -# ------------------------------- -# Metadata -# ------------------------------- -class MetaData(): - cu_seqlens_q: Optional[torch.Tensor] = None - cu_seqlens_k: Optional[torch.Tensor] = None - max_seqlens_q: int = 0 - max_seqlens_k: int = 0 - bias: Optional[torch.Tensor] = None - alibi_slopes: Optional[torch.Tensor] = None - causal: bool = False - num_contexts = 0 - varlen: bool = False - layout: Optional[Literal["bshd", "bhsd", "thd"]] = None - cache_seqlens: Optional[torch.Tensor] = None - cache_batch_idx = None - packing: Optional[bool] = None - return_softmax: bool = False - dropout_p: float = 0.0 - philox_seed: Optional[int] = None - philox_offset : Optional[int]= None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing. - # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW. - rotary_sin: Optional[torch.Tensor] = None - rotary_cos: Optional[torch.Tensor] = None - rotary_interleaved: bool = False - rotary_conjunction: bool = False - window_size_left: int = -1 - window_size_right: int = -1 - - - def __repr__(self) -> str: - return (f"MetaData(\n" - f" sm_scale={self.sm_scale},\n" - f" cu_seqlens_q={self.cu_seqlens_q},\n" - f" cu_seqlens_k={self.cu_seqlens_k},\n" - f" max_seqlens_q={self.max_seqlens_q},\n" - f" max_seqlens_k={self.max_seqlens_k},\n" - f" bias={self.bias},\n" - f" alibi_slopes={self.alibi_slopes},\n" - f" causal={self.causal},\n" - f" num_contexts={self.num_contexts},\n" - f" varlen={self.varlen},\n" - f" layout={self.layout},\n" - f" cache_seqlens={self.cache_seqlens},\n" - f" cache_batch_idx={self.cache_batch_idx},\n" - f" dropout_p={self.dropout_p},\n" - f" return_softmax={self.return_softmax}\n" - f" window_size_left={self.window_size_left},\n" - f" window_size_right={self.window_size_right},\n" - f")") - - def __init__(self, sm_scale=1.0): - self.sm_scale = sm_scale - - def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k): - self.varlen = True - self.layout = 'thd' - self.cu_seqlens_q = cu_seqlens_q - self.cu_seqlens_k = cu_seqlens_k - self.max_seqlens_q = max_seqlen_q - self.max_seqlens_k = max_seqlen_k - - # Without "varlen", there should still be one sequence. - assert len(cu_seqlens_q) >= 2 - assert len(cu_seqlens_q) == len(cu_seqlens_k) - - def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): - assert bias.is_cuda - assert bias.dim() == 4 - assert bias.shape[0] == 1 - assert bias.shape[2:] == (seqlen_q, seqlen_k) - self.bias = bias - - def need_alibi(self, alibi_slopes, batch, nheads): - assert alibi_slopes.is_cuda - assert alibi_slopes.dim() == 2 - assert alibi_slopes.shape[0] == batch - assert alibi_slopes.shape[1] == nheads - self.alibi_slopes = alibi_slopes - - def need_causal(self, causal): - self.causal = causal - - def need_rotary(self, sin, cos, rotary_interleaved, rotary_conjunction=False): - self.rotary_sin = sin - self.rotary_cos = cos - self.rotary_interleaved = rotary_interleaved - self.rotary_conjunction = rotary_conjunction - - def need_dropout(self, dropout_p, return_softmax): - self.dropout_p = dropout_p - self.return_softmax = return_softmax - self.philox_seed, self.philox_offset = 0x1BF58, 0x1D4B49 - - def check_args(self, q, k, v, o): - assert q.dim() == k.dim() and q.dim() == v.dim() - - batch, nheads_q, nheads_k, head_size, _, _ = get_shapes_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k) - if self.varlen: - assert q.dim() == 3 - assert self.cu_seqlens_q is not None - assert self.cu_seqlens_k is not None - assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) - # TODO: Remove once bias is supported with varlen - assert self.bias is None - # assert not self.return_softmax - else: - assert q.dim() == 4 - assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 - assert self.cu_seqlens_q is None and self.cu_seqlens_k is None - # assert k.shape == v.shape - assert q.shape[-1] == k.shape[-1] # and q.shape[-1] == v.shape[-1] - # TODO: Change assert if we support qkl f8 and v f16 - assert q.dtype == k.dtype and q.dtype == v.dtype - assert o.shape[:-1] == q.shape[:-1] and o.shape[-1] == v.shape[-1] - assert (nheads_q % nheads_k) == 0 - assert self.layout is not None - assert self.layout == 'thd' or not self.varlen - # ------------------------------- # Input Helper # ------------------------------- @@ -155,14 +65,17 @@ def random_seqlens_composition(SEQ_LEN, BATCH): # generate a random composition of N into Z positive parts. idx = torch.randperm(SEQ_LEN - 1)[: BATCH - 1] + 1 idx, _ = torch.sort(idx) - breakpoints = torch.cat([ - torch.tensor([0], dtype=torch.long), - idx, - torch.tensor([SEQ_LEN], dtype=torch.long), - ]) + breakpoints = torch.cat( + [ + torch.tensor([0], dtype=torch.long), + idx, + torch.tensor([SEQ_LEN], dtype=torch.long), + ] + ) seqlens = (breakpoints[1:] - breakpoints[:-1]).to(torch.int32) return seqlens + def generate_varlen_tensor( total_seqlen: int, num_heads: int, @@ -171,7 +84,7 @@ def generate_varlen_tensor( equal_seqlens: bool = False, device: str = "cuda", dtype: torch.dtype = torch.float16, - mode: Literal["random", "ones", "incremental", "identity"] = "random" + mode: Literal["random", "ones", "incremental", "identity"] = "random", ): if DEBUG: print("total_seqlen", total_seqlen) @@ -186,23 +99,28 @@ def generate_varlen_tensor( # get valid batch_size if batch_size is None: - valid_batch_sizes = [bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen] + valid_batch_sizes = [ + bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen + ] batch_size = random.choice(valid_batch_sizes) - + # get seqlens if equal_seqlens: seqlens = torch.full( - (batch_size,), - total_seqlen // batch_size, - dtype=torch.int32, - device=device + (batch_size,), total_seqlen // batch_size, dtype=torch.int32, device=device ) seqlens[-1] += total_seqlen % batch_size else: seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device) # create cumulative sequence lengths - cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)]).to(torch.int32).to(device=device) + cu_seqlens = ( + torch.cat( + [torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)] + ) + .to(torch.int32) + .to(device=device) + ) max_seqlen = torch.max(seqlens).to(torch.int32).item() # create varlen tensor based on mode @@ -210,8 +128,8 @@ def generate_varlen_tensor( x = torch.zeros(total_seqlen, num_heads, head_size, dtype=dtype, device=device) for i in range(batch_size): start = cu_seqlens[i].item() - end = cu_seqlens[i+1].item() - length = end - start + end = cu_seqlens[i + 1].item() + length = end - start x[start:end, :, :] = ( torch.arange(length, dtype=dtype, device=device) @@ -223,14 +141,16 @@ def generate_varlen_tensor( # for each batch, create identity pattern within that batch's sequence for i in range(batch_size): start = cu_seqlens[i].item() - end = cu_seqlens[i+1].item() + end = cu_seqlens[i + 1].item() length = end - start - + # create identity pattern for positions within this batch for pos in range(min(length, head_size)): x[start + pos, :, pos] = 1.0 elif mode == "random": - x = torch.randn((total_seqlen, num_heads, head_size), dtype=dtype, device=device) + x = torch.randn( + (total_seqlen, num_heads, head_size), dtype=dtype, device=device + ) elif mode == "ones": x = torch.ones((total_seqlen, num_heads, head_size), dtype=dtype, device=device) else: @@ -238,14 +158,25 @@ def generate_varlen_tensor( if is_fp8_dtype: # cast to fp8 - x, descale_x = cast_to_fp8(x, og_fp8_dtype, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + x, descale_x = cast_to_fp8( + x, og_fp8_dtype, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) x.requires_grad_() return x, cu_seqlens, max_seqlen, descale_x else: x.requires_grad_() return x, cu_seqlens, max_seqlen -def generate_bshd_tensor(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", mode: Literal["random", "ones", "incremental", "identity"] = "random"): + +def generate_bshd_tensor( + BATCH, + SEQ_LEN, + NUM_HEADS, + D_HEAD, + dtype: torch.dtype = torch.float16, + device="cuda", + mode: Literal["random", "ones", "incremental", "identity"] = "random", +): # save fp8 type is_fp8_dtype = is_dtype_fp8(dtype) if is_fp8_dtype: @@ -255,7 +186,12 @@ def generate_bshd_tensor(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dtype = # gen tensor based on mode tensor_shape = (BATCH, SEQ_LEN, NUM_HEADS, D_HEAD) if mode == "incremental": - x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, SEQ_LEN, 1, 1).expand(*tensor_shape).contiguous() + x = ( + torch.arange(SEQ_LEN, dtype=dtype, device=device) + .view(1, SEQ_LEN, 1, 1) + .expand(*tensor_shape) + .contiguous() + ) elif mode == "identity": x = torch.zeros(tensor_shape, dtype=dtype, device=device) # create identity pattern: position i has value 1 at dimension i @@ -267,7 +203,7 @@ def generate_bshd_tensor(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dtype = x = torch.ones(tensor_shape, dtype=dtype, device=device) else: raise ValueError(f"Unkown mode {mode}") - + if is_fp8_dtype: # cast to fp8 x, descale_x = cast_to_fp8(x, og_fp8_dtype, "bshd") @@ -277,17 +213,31 @@ def generate_bshd_tensor(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dtype = x.requires_grad_() return x -def generate_bhsd_tensor(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", mode: Literal["random", "ones", "incremental", "identity"] = "random"): + +def generate_bhsd_tensor( + BATCH, + NUM_HEADS, + SEQ_LEN, + D_HEAD, + dtype: torch.dtype = torch.float16, + device="cuda", + mode: Literal["random", "ones", "incremental", "identity"] = "random", +): # save fp8 type is_fp8_dtype = is_dtype_fp8(dtype) if is_fp8_dtype: og_fp8_dtype = dtype dtype = torch.float32 - + # gen tensor based on mode tensor_shape = (BATCH, NUM_HEADS, SEQ_LEN, D_HEAD) if mode == "incremental": - x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, 1, SEQ_LEN, 1).expand(*tensor_shape).contiguous() + x = ( + torch.arange(SEQ_LEN, dtype=dtype, device=device) + .view(1, 1, SEQ_LEN, 1) + .expand(*tensor_shape) + .contiguous() + ) elif mode == "identity": x = torch.zeros(tensor_shape, dtype=dtype, device=device) # create identity pattern: position i has value 1 at dimension i @@ -299,14 +249,23 @@ def generate_bhsd_tensor(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype: torch.dtype = x = torch.ones(tensor_shape, dtype=dtype, device=device) else: raise ValueError(f"Unkown mode {mode}") - + if is_fp8_dtype: raise ValueError("fp8 not supported for bhsd yet") else: x.requires_grad_() return x -def generate_bshd_qkv_packed(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", DEBUG_INPUT=False): + +def generate_bshd_qkv_packed( + BATCH, + SEQ_LEN, + NUM_HEADS, + D_HEAD, + dtype: torch.dtype = torch.float16, + device="cuda", + DEBUG_INPUT=False, +): """Generate QKV packed tensor with shape (BATCH, SEQ_LEN, 3, NUM_HEADS, D_HEAD)""" # save fp8 type is_fp8_dtype = is_dtype_fp8(dtype) @@ -317,10 +276,15 @@ def generate_bshd_qkv_packed(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dty # gen tensor tensor_shape = (BATCH, SEQ_LEN, 3, NUM_HEADS, D_HEAD) if DEBUG_INPUT: - x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, SEQ_LEN, 1, 1, 1).expand(*tensor_shape).contiguous() + x = ( + torch.arange(SEQ_LEN, dtype=dtype, device=device) + .view(1, SEQ_LEN, 1, 1, 1) + .expand(*tensor_shape) + .contiguous() + ) else: x = torch.randn(tensor_shape, dtype=dtype, device=device) - + if is_fp8_dtype: # cast to fp8 - need to handle the packed dimension raise NotImplementedError("FP8 not supported for QKV packing yet") @@ -329,7 +293,15 @@ def generate_bshd_qkv_packed(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dty return x -def generate_bshd_kv_packed(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", DEBUG_INPUT=False): +def generate_bshd_kv_packed( + BATCH, + SEQ_LEN, + NUM_HEADS, + D_HEAD, + dtype: torch.dtype = torch.float16, + device="cuda", + DEBUG_INPUT=False, +): """Generate KV packed tensor with shape (BATCH, SEQ_LEN, 2, NUM_HEADS, D_HEAD)""" # save fp8 type is_fp8_dtype = is_dtype_fp8(dtype) @@ -340,10 +312,15 @@ def generate_bshd_kv_packed(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dtyp # gen tensor tensor_shape = (BATCH, SEQ_LEN, 2, NUM_HEADS, D_HEAD) if DEBUG_INPUT: - x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, SEQ_LEN, 1, 1, 1).expand(*tensor_shape).contiguous() + x = ( + torch.arange(SEQ_LEN, dtype=dtype, device=device) + .view(1, SEQ_LEN, 1, 1, 1) + .expand(*tensor_shape) + .contiguous() + ) else: x = torch.randn(tensor_shape, dtype=dtype, device=device) - + if is_fp8_dtype: # cast to fp8 - need to handle the packed dimension raise NotImplementedError("FP8 not supported for KV packing yet") @@ -352,21 +329,34 @@ def generate_bshd_kv_packed(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype: torch.dtyp return x -def generate_bhsd_qkv_packed(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", DEBUG_INPUT=False): +def generate_bhsd_qkv_packed( + BATCH, + NUM_HEADS, + SEQ_LEN, + D_HEAD, + dtype: torch.dtype = torch.float16, + device="cuda", + DEBUG_INPUT=False, +): """Generate QKV packed tensor with shape (BATCH, 3, NUM_HEADS, SEQ_LEN, D_HEAD)""" # save fp8 type is_fp8_dtype = is_dtype_fp8(dtype) if is_fp8_dtype: og_fp8_dtype = dtype dtype = torch.float32 - + # gen tensor tensor_shape = (BATCH, 3, NUM_HEADS, SEQ_LEN, D_HEAD) if DEBUG_INPUT: - x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, 1, 1, SEQ_LEN, 1).expand(*tensor_shape).contiguous() + x = ( + torch.arange(SEQ_LEN, dtype=dtype, device=device) + .view(1, 1, 1, SEQ_LEN, 1) + .expand(*tensor_shape) + .contiguous() + ) else: x = torch.randn(tensor_shape, dtype=dtype, device=device) - + if is_fp8_dtype: # cast to fp8 - need to handle the packed dimension raise NotImplementedError("FP8 not supported for QKV packing yet") @@ -375,21 +365,34 @@ def generate_bhsd_qkv_packed(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype: torch.dty return x -def generate_bhsd_kv_packed(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype: torch.dtype = torch.float16, device="cuda", DEBUG_INPUT=False): +def generate_bhsd_kv_packed( + BATCH, + NUM_HEADS, + SEQ_LEN, + D_HEAD, + dtype: torch.dtype = torch.float16, + device="cuda", + DEBUG_INPUT=False, +): """Generate KV packed tensor with shape (BATCH, 2, NUM_HEADS, SEQ_LEN, D_HEAD)""" # save fp8 type is_fp8_dtype = is_dtype_fp8(dtype) if is_fp8_dtype: og_fp8_dtype = dtype dtype = torch.float32 - + # gen tensor tensor_shape = (BATCH, 2, NUM_HEADS, SEQ_LEN, D_HEAD) if DEBUG_INPUT: - x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, 1, 1, SEQ_LEN, 1).expand(*tensor_shape).contiguous() + x = ( + torch.arange(SEQ_LEN, dtype=dtype, device=device) + .view(1, 1, 1, SEQ_LEN, 1) + .expand(*tensor_shape) + .contiguous() + ) else: x = torch.randn(tensor_shape, dtype=dtype, device=device) - + if is_fp8_dtype: # cast to fp8 - need to handle the packed dimension raise NotImplementedError("FP8 not supported for KV packing yet") @@ -406,7 +409,7 @@ def generate_varlen_qkv_packed( equal_seqlens: bool = False, device: str = "cuda", dtype: torch.dtype = torch.float16, - DEBUG_INPUT: bool = False + DEBUG_INPUT: bool = False, ): """Generate varlen QKV packed tensor with shape (total_seqlen, 3, num_heads, head_size)""" if DEBUG: @@ -423,31 +426,38 @@ def generate_varlen_qkv_packed( # get valid batch_size if batch_size is None: - valid_batch_sizes = [bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen] + valid_batch_sizes = [ + bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen + ] batch_size = random.choice(valid_batch_sizes) - + # get seqlens if equal_seqlens: seqlens = torch.full( - (batch_size,), - total_seqlen // batch_size, - dtype=torch.int32, - device=device + (batch_size,), total_seqlen // batch_size, dtype=torch.int32, device=device ) seqlens[-1] += total_seqlen % batch_size else: seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device) # create cumulative sequence lengths - cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)]).to(torch.int32).to(device=device) + cu_seqlens = ( + torch.cat( + [torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)] + ) + .to(torch.int32) + .to(device=device) + ) max_seqlen = torch.max(seqlens).to(torch.int32).item() # create varlen qkv packed tensor if DEBUG_INPUT: - x = torch.zeros(total_seqlen, 3, num_heads, head_size, dtype=dtype, device=device) + x = torch.zeros( + total_seqlen, 3, num_heads, head_size, dtype=dtype, device=device + ) for i in range(batch_size): start = cu_seqlens[i].item() - end = cu_seqlens[i+1].item() + end = cu_seqlens[i + 1].item() length = end - start x[start:end, :, :, :] = ( @@ -456,7 +466,9 @@ def generate_varlen_qkv_packed( .expand(length, 3, num_heads, head_size) ) else: - x = torch.randn((total_seqlen, 3, num_heads, head_size), dtype=dtype, device=device) + x = torch.randn( + (total_seqlen, 3, num_heads, head_size), dtype=dtype, device=device + ) if is_fp8_dtype: # cast to fp8 - need to handle the packed dimension @@ -474,7 +486,7 @@ def generate_varlen_kv_packed( equal_seqlens: bool = False, device: str = "cuda", dtype: torch.dtype = torch.float16, - DEBUG_INPUT: bool = False + DEBUG_INPUT: bool = False, ): """Generate varlen KV packed tensor with shape (total_seqlen, 2, num_heads, head_size)""" if DEBUG: @@ -491,31 +503,38 @@ def generate_varlen_kv_packed( # get valid batch_size if batch_size is None: - valid_batch_sizes = [bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen] + valid_batch_sizes = [ + bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen + ] batch_size = random.choice(valid_batch_sizes) - + # get seqlens if equal_seqlens: seqlens = torch.full( - (batch_size,), - total_seqlen // batch_size, - dtype=torch.int32, - device=device + (batch_size,), total_seqlen // batch_size, dtype=torch.int32, device=device ) seqlens[-1] += total_seqlen % batch_size else: seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device) # create cumulative sequence lengths - cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)]).to(torch.int32).to(device=device) + cu_seqlens = ( + torch.cat( + [torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)] + ) + .to(torch.int32) + .to(device=device) + ) max_seqlen = torch.max(seqlens).to(torch.int32).item() # create varlen kv packed tensor if DEBUG_INPUT: - x = torch.zeros(total_seqlen, 2, num_heads, head_size, dtype=dtype, device=device) + x = torch.zeros( + total_seqlen, 2, num_heads, head_size, dtype=dtype, device=device + ) for i in range(batch_size): start = cu_seqlens[i].item() - end = cu_seqlens[i+1].item() + end = cu_seqlens[i + 1].item() length = end - start x[start:end, :, :, :] = ( @@ -524,7 +543,9 @@ def generate_varlen_kv_packed( .expand(length, 2, num_heads, head_size) ) else: - x = torch.randn((total_seqlen, 2, num_heads, head_size), dtype=dtype, device=device) + x = torch.randn( + (total_seqlen, 2, num_heads, head_size), dtype=dtype, device=device + ) if is_fp8_dtype: # cast to fp8 - need to handle the packed dimension @@ -533,7 +554,6 @@ def generate_varlen_kv_packed( x.requires_grad_() return x, cu_seqlens, max_seqlen -# Replace the existing input_helper function in utils.py with this updated version def input_helper( BATCH: int, @@ -547,7 +567,7 @@ def input_helper( dtype: torch.dtype, layout: Literal["bshd", "bhsd", "thd"], packing: Optional[Literal["kv", "qkv"]] = None, - device: Literal["cpu", "cuda"] = "cuda" + device: Literal["cpu", "cuda"] = "cuda", ): torch.manual_seed(20) is_fp8_dtype = is_dtype_fp8(dtype) @@ -557,136 +577,284 @@ def input_helper( TOTAL_SEQLENS_Q = BATCH * N_CTX_Q TOTAL_SEQLENS_K = BATCH * N_CTX_K equal_seqlens = False - + # deal with packing if packing is None: # gen tensors if is_fp8_dtype: - q, cu_seqlens_q, max_seqlen_q, descale_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) - k, cu_seqlens_k, max_seqlen_k, descale_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) - v, _, _, descale_v = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) - do, _, _, descale_do = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) + q, cu_seqlens_q, max_seqlen_q, descale_q = generate_varlen_tensor( + TOTAL_SEQLENS_Q, + HQ, + D_HEAD, + batch_size=BATCH, + dtype=dtype, + device=device, + equal_seqlens=equal_seqlens, + ) + k, cu_seqlens_k, max_seqlen_k, descale_k = generate_varlen_tensor( + TOTAL_SEQLENS_K, + HK, + D_HEAD, + batch_size=BATCH, + dtype=dtype, + device=device, + equal_seqlens=equal_seqlens, + ) + v, _, _, descale_v = generate_varlen_tensor( + TOTAL_SEQLENS_K, + HK, + D_HEAD, + batch_size=BATCH, + dtype=dtype, + device=device, + equal_seqlens=equal_seqlens, + ) + do, _, _, descale_do = generate_varlen_tensor( + TOTAL_SEQLENS_Q, + HQ, + D_HEAD, + batch_size=BATCH, + dtype=dtype, + device=device, + equal_seqlens=equal_seqlens, + ) else: - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) - k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) - v, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) - do, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) + q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor( + TOTAL_SEQLENS_Q, + HQ, + D_HEAD, + batch_size=BATCH, + dtype=dtype, + device=device, + equal_seqlens=equal_seqlens, + ) + k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor( + TOTAL_SEQLENS_K, + HK, + D_HEAD, + batch_size=BATCH, + dtype=dtype, + device=device, + equal_seqlens=equal_seqlens, + ) + v, _, _ = generate_varlen_tensor( + TOTAL_SEQLENS_K, + HK, + D_HEAD, + batch_size=BATCH, + dtype=dtype, + device=device, + equal_seqlens=equal_seqlens, + ) + do, _, _ = generate_varlen_tensor( + TOTAL_SEQLENS_Q, + HQ, + D_HEAD, + batch_size=BATCH, + dtype=dtype, + device=device, + equal_seqlens=equal_seqlens, + ) elif packing == "kv": # gen tensors with kv packing if is_fp8_dtype: raise ValueError("FP8 not supported for KV packing yet") else: - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) - kv, cu_seqlens_k, max_seqlen_k = generate_varlen_kv_packed(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) - do, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) + q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor( + TOTAL_SEQLENS_Q, + HQ, + D_HEAD, + batch_size=BATCH, + dtype=dtype, + device=device, + equal_seqlens=equal_seqlens, + ) + kv, cu_seqlens_k, max_seqlen_k = generate_varlen_kv_packed( + TOTAL_SEQLENS_K, + HK, + D_HEAD, + batch_size=BATCH, + dtype=dtype, + device=device, + equal_seqlens=equal_seqlens, + ) + do, _, _ = generate_varlen_tensor( + TOTAL_SEQLENS_Q, + HQ, + D_HEAD, + batch_size=BATCH, + dtype=dtype, + device=device, + equal_seqlens=equal_seqlens, + ) elif packing == "qkv": # qkv packing - requires same sequence length for q and k - assert N_CTX_Q == N_CTX_K, "For QKV packing, Q and K must have same sequence length" + assert ( + N_CTX_Q == N_CTX_K + ), "For QKV packing, Q and K must have same sequence length" assert HQ == HK, "For QKV packing, Q and K must have same number of heads" - + if is_fp8_dtype: raise ValueError("FP8 not supported for QKV packing yet") else: - qkv, cu_seqlens_q, max_seqlen_q = generate_varlen_qkv_packed(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) + qkv, cu_seqlens_q, max_seqlen_q = generate_varlen_qkv_packed( + TOTAL_SEQLENS_Q, + HQ, + D_HEAD, + batch_size=BATCH, + dtype=dtype, + device=device, + equal_seqlens=equal_seqlens, + ) cu_seqlens_k = cu_seqlens_q max_seqlen_k = max_seqlen_q - do, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) - - # setup metadata - sm_scale = D_HEAD**-0.5 - metadata = MetaData(sm_scale=sm_scale) - metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) - metadata.need_causal(CAUSAL) - metadata.need_dropout(DROPOUT_P, True) - - elif layout == 'bshd' or layout == "bhsd": + do, _, _ = generate_varlen_tensor( + TOTAL_SEQLENS_Q, + HQ, + D_HEAD, + batch_size=BATCH, + dtype=dtype, + device=device, + equal_seqlens=equal_seqlens, + ) + + elif layout == "bshd" or layout == "bhsd": # deal with packing if packing is None: # gen tensors if layout == "bshd": if is_fp8_dtype: - q, descale_q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) - k, descale_k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device) - v, descale_v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device) - do, descale_do = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) + q, descale_q = generate_bshd_tensor( + BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device + ) + k, descale_k = generate_bshd_tensor( + BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device + ) + v, descale_v = generate_bshd_tensor( + BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device + ) + do, descale_do = generate_bshd_tensor( + BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device + ) else: - q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) - k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device) - v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device) - do = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) + q = generate_bshd_tensor( + BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device + ) + k = generate_bshd_tensor( + BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device + ) + v = generate_bshd_tensor( + BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device + ) + do = generate_bshd_tensor( + BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device + ) elif layout == "bhsd": - q, descale_q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) - k, descale_k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device) - v, descale_v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device) - do, descale_do = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) + q, descale_q = generate_bhsd_tensor( + BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device + ) + k, descale_k = generate_bhsd_tensor( + BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device + ) + v, descale_v = generate_bhsd_tensor( + BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device + ) + do, descale_do = generate_bhsd_tensor( + BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device + ) else: - q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) - k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device) - v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device) - do = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) + q = generate_bhsd_tensor( + BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device + ) + k = generate_bhsd_tensor( + BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device + ) + v = generate_bhsd_tensor( + BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device + ) + do = generate_bhsd_tensor( + BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device + ) elif packing == "kv": # gen tensors with kv packing if is_fp8_dtype: raise ValueError("FP8 not supported for KV packing yet") else: if layout == "bshd": - q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) - kv = generate_bshd_kv_packed(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device) - do = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) + q = generate_bshd_tensor( + BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device + ) + kv = generate_bshd_kv_packed( + BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device + ) + do = generate_bshd_tensor( + BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device + ) elif layout == "bhsd": - q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) - kv = generate_bhsd_kv_packed(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device) - do = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) + q = generate_bhsd_tensor( + BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device + ) + kv = generate_bhsd_kv_packed( + BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device + ) + do = generate_bhsd_tensor( + BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device + ) elif packing == "qkv": # qkv packing - requires same sequence length for q and k - assert N_CTX_Q == N_CTX_K, "For QKV packing, Q and K must have same sequence length" + assert ( + N_CTX_Q == N_CTX_K + ), "For QKV packing, Q and K must have same sequence length" assert HQ == HK, "For QKV packing, Q and K must have same number of heads" - + if is_fp8_dtype: raise ValueError("FP8 not supported for QKV packing yet") else: if layout == "bshd": - qkv = generate_bshd_qkv_packed(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) - do = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) + qkv = generate_bshd_qkv_packed( + BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device + ) + do = generate_bshd_tensor( + BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device + ) elif layout == "bhsd": - qkv = generate_bhsd_qkv_packed(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) - do = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) - - # setup metadata - sm_scale = D_HEAD**-0.5 - metadata = MetaData(sm_scale=sm_scale) - metadata.max_seqlens_q = N_CTX_Q - metadata.max_seqlens_k = N_CTX_K - metadata.layout = layout - metadata.need_causal(CAUSAL) - metadata.need_dropout(DROPOUT_P, True) + qkv = generate_bhsd_qkv_packed( + BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device + ) + do = generate_bhsd_tensor( + BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device + ) + else: raise ValueError(f"Unknown layout: {layout}") # return based on packing if packing is None: if is_fp8_dtype: - return (q, descale_q), (k, descale_k), (v, descale_v), (do, descale_do), metadata + return (q, descale_q), (k, descale_k), (v, descale_v), (do, descale_do) else: - return q, k, v, do, metadata + return q, k, v, do elif packing == "kv": if is_fp8_dtype: raise ValueError("FP8 not supported kv packing yet") else: - return q, kv, do, metadata + return q, kv, do elif packing == "qkv": if is_fp8_dtype: raise ValueError("FP8 not supported qkv packing yet") else: - return qkv, do, metadata + return qkv, do else: assert False, f"Unsupported packing mode: {packing}" + # ------------------------------- # Alibi # ------------------------------- @triton.jit -def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): +def compute_alibi_block( + alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False +): # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix # for casual mask we want something like this where (1 is kept and 0 is masked) # seqlen_q = 2 and seqlen_k = 5 @@ -717,11 +885,17 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo else: return alibi_block + # ------------------------------- # FP8 # ------------------------------- def is_dtype_fp8(dtype): - if dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}: + if dtype in { + torch.float8_e4m3fnuz, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + }: if arch_supports_fp8(): return True else: @@ -729,36 +903,50 @@ def is_dtype_fp8(dtype): else: return False + def is_fp8(x): return is_dtype_fp8(x.dtype) + @triton.jit def compute_fp8_scaling_factors(x, fp8_max: tl.constexpr): # compute fp8 scaling and descaling factor for a block - x_amax = tl.max(tl.abs(x)) # NOTE: abs deals with negative values + x_amax = tl.max(tl.abs(x)) # NOTE: abs deals with negative values x_amax = tl.where(x_amax <= 1e-9, 1e-9, x_amax) scale_x = fp8_max / x_amax descale_x = x_amax / fp8_max return scale_x, descale_x + @triton.jit def _cast_varlen_to_fp8_kernel_2d( - X, X_fp8, Descale, - cu_seqlens, H, MAX_SEQLEN, - stride_batch, stride_seq, stride_head, stride_dim, - stride_out_batch, stride_out_seq, stride_out_head, stride_out_dim, - stride_desc_batch, stride_desc_head, - FP8_CLAMP_VAL, + X, + X_fp8, + Descale, + cu_seqlens, + H, + MAX_SEQLEN, + stride_batch, + stride_seq, + stride_head, + stride_dim, + stride_out_batch, + stride_out_seq, + stride_out_head, + stride_out_dim, + stride_desc_batch, + stride_desc_head, + FP8_CLAMP_VAL, FP8_MAX, BLOCK_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr, ACTUAL_HEAD_DIM: tl.constexpr, - IS_VARLEN: tl.constexpr - ): + IS_VARLEN: tl.constexpr, +): # Process one (batch, head) pair per kernel b_id = tl.program_id(0) h_id = tl.program_id(1) - + # Get sequence bounds for this batch if IS_VARLEN: seq_start = tl.load(cu_seqlens + b_id) @@ -766,11 +954,11 @@ def _cast_varlen_to_fp8_kernel_2d( seqlen = seq_end - seq_start else: seq_start = 0 - seqlen = MAX_SEQLEN - + seqlen = MAX_SEQLEN + # initialize max value tracker x_max_val = 0.0 - + # STEP 1: Find max absolute value across the entire sequence num_of_blocks = tl.cdiv(seqlen, BLOCK_SIZE) for blk_idx in range(0, num_of_blocks): @@ -778,7 +966,7 @@ def _cast_varlen_to_fp8_kernel_2d( # offsets offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) offs_dim = tl.arange(0, HEAD_DIM) - + # Create mask for valid elements mask_seq = offs_seq[:, None] < seqlen if ACTUAL_HEAD_DIM != HEAD_DIM: @@ -786,27 +974,33 @@ def _cast_varlen_to_fp8_kernel_2d( mask_seq = mask_seq & mask_dim # Load block - adj_x = b_id * stride_batch + h_id * stride_head + seq_start * stride_seq + offs_seq[:, None] * stride_seq + offs_dim[None, :] * stride_dim + adj_x = ( + b_id * stride_batch + + h_id * stride_head + + seq_start * stride_seq + + offs_seq[:, None] * stride_seq + + offs_dim[None, :] * stride_dim + ) x_block = tl.load(X + adj_x, mask=mask_seq, other=0.0) # print("x_block:", x_block) - + # Find max absolute value in this block block_max = tl.max(tl.abs(x_block)) # print("block_max:", block_max) - + # Update overall max x_max_val = tl.maximum(x_max_val, block_max) # print("x_max_val:", x_max_val) - + # clamp to avoid division by zero issues x_max_val = tl.maximum(x_max_val, FP8_CLAMP_VAL) - + # compute scale and descale factors for the entire sequence scale = FP8_MAX / x_max_val descale = x_max_val / FP8_MAX - + # store descale factor for this (batch, head) pair - desc_ptr = Descale + b_id * stride_desc_batch + h_id# * stride_desc_head + desc_ptr = Descale + b_id * stride_desc_batch + h_id # * stride_desc_head tl.store(desc_ptr, descale) # STEP 2: Apply scaling to the entire sequence and convert to FP8 @@ -814,31 +1008,44 @@ def _cast_varlen_to_fp8_kernel_2d( # offsets offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) offs_dim = tl.arange(0, HEAD_DIM) - + # Create mask for valid elements mask_seq = offs_seq[:, None] < seqlen if ACTUAL_HEAD_DIM != HEAD_DIM: mask_dim = offs_dim[None, :] < ACTUAL_HEAD_DIM mask_seq = mask_seq & mask_dim - + # Load block - Using the fixed addressing - addr = b_id * stride_batch + h_id * stride_head + seq_start * stride_seq + offs_seq[:, None] * stride_seq + offs_dim[None, :] * stride_dim + addr = ( + b_id * stride_batch + + h_id * stride_head + + seq_start * stride_seq + + offs_seq[:, None] * stride_seq + + offs_dim[None, :] * stride_dim + ) x_block = tl.load(X + addr, mask=mask_seq, other=0.0) - + # Apply scale and convert to FP8 x_fp8_block = (x_block * scale).to(X_fp8.type.element_ty) - + # Store results - addr_out = b_id * stride_out_batch + h_id * stride_out_head + seq_start * stride_out_seq + offs_seq[:, None] * stride_out_seq + offs_dim[None, :] * stride_out_dim + addr_out = ( + b_id * stride_out_batch + + h_id * stride_out_head + + seq_start * stride_out_seq + + offs_seq[:, None] * stride_out_seq + + offs_dim[None, :] * stride_out_dim + ) tl.store(X_fp8 + addr_out, x_fp8_block, mask=mask_seq) + def cast_to_fp8( x: torch.Tensor, fp8_dtype: torch.dtype, layout: Literal["bshd", "thd"], clamp_val: float = 1e-9, cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None + max_seqlen: Optional[int] = None, ) -> tuple[torch.Tensor, torch.Tensor]: if False: print() @@ -850,10 +1057,17 @@ def cast_to_fp8( print("clamp_val:", clamp_val) # check types are valid - assert x.dtype in {torch.float16, torch.float32, torch.float64, torch.bfloat16} and is_dtype_fp8(fp8_dtype), f"Cannot cast {x.dtype} to {fp8_dtype}" + assert x.dtype in { + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + } and is_dtype_fp8(fp8_dtype), f"Cannot cast {x.dtype} to {fp8_dtype}" # extract dimensions - batch, max_seqlen_final, num_heads, head_dim = get_shape_from_layout(x, layout, cu_seqlens, max_seqlen) + batch, max_seqlen_final, num_heads, head_dim = get_shape_from_layout( + x, layout, cu_seqlens, max_seqlen + ) is_varlen = layout == "thd" fp8_max = torch.finfo(fp8_dtype).max if False: @@ -868,12 +1082,18 @@ def cast_to_fp8( # kernel params x_fp8 = torch.zeros_like(x, dtype=fp8_dtype) - descale_factors = torch.zeros((batch, num_heads), device=x.device, dtype=torch.float32) + descale_factors = torch.zeros( + (batch, num_heads), device=x.device, dtype=torch.float32 + ) BLOCK_SIZE = 128 # calculate strides - stride_batch, stride_head, stride_seq, stride_dim = get_stride_from_layout(x, layout) - stride_out_batch, stride_out_head, stride_out_seq, stride_out_dim = get_stride_from_layout(x_fp8, layout) + stride_batch, stride_head, stride_seq, stride_dim = get_stride_from_layout( + x, layout + ) + stride_out_batch, stride_out_head, stride_out_seq, stride_out_dim = ( + get_stride_from_layout(x_fp8, layout) + ) stride_desc_batch, stride_desc_head = descale_factors.stride() if False: @@ -890,23 +1110,36 @@ def cast_to_fp8( grid = (batch, num_heads) _cast_varlen_to_fp8_kernel_2d[grid]( - x, x_fp8, descale_factors, - cu_seqlens, num_heads, max_seqlen_final, - stride_batch, stride_seq, stride_head, stride_dim, - stride_out_batch, stride_out_seq, stride_out_head, stride_out_dim, - stride_desc_batch, stride_desc_head, - clamp_val, fp8_max, + x, + x_fp8, + descale_factors, + cu_seqlens, + num_heads, + max_seqlen_final, + stride_batch, + stride_seq, + stride_head, + stride_dim, + stride_out_batch, + stride_out_seq, + stride_out_head, + stride_out_dim, + stride_desc_batch, + stride_desc_head, + clamp_val, + fp8_max, BLOCK_SIZE=BLOCK_SIZE, - HEAD_DIM=padded_head_dim, + HEAD_DIM=padded_head_dim, ACTUAL_HEAD_DIM=head_dim, - IS_VARLEN=is_varlen + IS_VARLEN=is_varlen, ) - + if False: print("x_fp8:", x_fp8, x_fp8.shape) print("descale_factors:", descale_factors, descale_factors.shape) return x_fp8, descale_factors + # ------------------------------- # Misc # ------------------------------- @@ -916,47 +1149,74 @@ def get_shape_from_layout( cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, ) -> tuple[int, int, int, int]: - if layout == 'bhsd': + if layout == "bhsd": batch, num_heads, max_seqlen_final, head_dim = x.shape - elif layout == 'bshd': + elif layout == "bshd": batch, max_seqlen_final, num_heads, head_dim = x.shape - elif layout == 'thd': + elif layout == "thd": total_seqlen, num_heads, head_dim = x.shape if cu_seqlens is None: - raise ValueError("cu_seqlens must be provided for varlen (thd) layout") + raise ValueError("cu_seqlens must be provided for varlen (thd) layout") if max_seqlen is None: raise ValueError("max_seqlen must be provided for varlen (thd) layout") - - batch, max_seqlen_final, num_heads, head_dim = len(cu_seqlens) - 1, max_seqlen, num_heads, head_dim + + batch, max_seqlen_final, num_heads, head_dim = ( + len(cu_seqlens) - 1, + max_seqlen, + num_heads, + head_dim, + ) else: assert False, "Got unsupported layout." return batch, max_seqlen_final, num_heads, head_dim -def get_shapes_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None): - batch_q, seqlen_q, nheads_q, head_size_q = get_shape_from_layout(q, layout, cu_seqlens_q, max_seqlen_q) - batch_k, seqlen_k, nheads_k, head_size_k = get_shape_from_layout(k, layout, cu_seqlens_k, max_seqlen_k) - +def get_shapes_from_layout( + q, + k, + layout, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, +): + batch_q, seqlen_q, nheads_q, head_size_q = get_shape_from_layout( + q, layout, cu_seqlens_q, max_seqlen_q + ) + batch_k, seqlen_k, nheads_k, head_size_k = get_shape_from_layout( + k, layout, cu_seqlens_k, max_seqlen_k + ) + # assert assert batch_q == batch_k assert head_size_q == head_size_k return batch_q, nheads_q, nheads_k, head_size_q, seqlen_q, seqlen_k -def get_stride_from_layout(x: torch.Tensor, layout:Literal["bshd", "bhsd", "thd"]): - if layout == 'thd': - strides = (0, x.stride(1), x.stride(0), x.stride(2)) - elif layout == 'bhsd': + +def get_stride_from_layout(x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd"]): + if layout == "thd": + strides = (0, x.stride(1), x.stride(0), x.stride(2)) + elif layout == "bhsd": strides = (x.stride(0), x.stride(1), x.stride(2), x.stride(3)) - elif layout == 'bshd': + elif layout == "bshd": strides = (x.stride(0), x.stride(2), x.stride(1), x.stride(3)) else: - assert False, 'Got unsupported layout.' + assert False, "Got unsupported layout." return strides -def get_shape_and_strides_from_layout(x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd"], cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None): - return get_shape_from_layout(x, layout, cu_seqlens, max_seqlen), get_stride_from_layout(x, layout) + +def get_shape_and_strides_from_layout( + x: torch.Tensor, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + return get_shape_from_layout( + x, layout, cu_seqlens, max_seqlen + ), get_stride_from_layout(x, layout) + def get_strides_from_layout(q, k, v, o, layout): q_strides = get_stride_from_layout(q, layout) @@ -965,6 +1225,7 @@ def get_strides_from_layout(q, k, v, o, layout): o_strides = get_stride_from_layout(o, layout) return q_strides, k_strides, v_strides, o_strides + def get_padded_headsize(size): # Get closest power of 2 over or equal to 32. padded_d_model = 1 << (size - 1).bit_length() @@ -973,19 +1234,28 @@ def get_padded_headsize(size): padded_d_model = max(padded_d_model, 16) return padded_d_model + def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k): - q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) - k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) + q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze( + -1 + ) # (N_CTX_Q, 1) + k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze( + 0 + ) # (1, N_CTX_K) relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) - return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) + return ( + -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos + ) # (Z, H, N_CTX_Q, N_CTX_K) + def round_multiple(x, m): return (x + m - 1) // m * m + def save_tensor_to_csv(tensor, filename, decimal_places=2): """ save a 2d tensor to csv file - + args: tensor: torch tensor of shape [rows, cols] filename: output csv filename @@ -994,46 +1264,63 @@ def save_tensor_to_csv(tensor, filename, decimal_places=2): # ensure tensor is 2d if tensor.ndim != 2: raise ValueError(f"tensor must be 2d, got shape {tensor.shape}") - + # ensure filename ends with .csv - if not filename.endswith('.csv'): - filename = filename + '.csv' - + if not filename.endswith(".csv"): + filename = filename + ".csv" + # save to csv using numpy - np.savetxt(filename, - tensor.detach().cpu().numpy(), - delimiter=',', - fmt=f'%.{decimal_places}f') + np.savetxt( + filename, + tensor.detach().cpu().numpy(), + delimiter=",", + fmt=f"%.{decimal_places}f", + ) + # ------------------------------- # Dropouts # ------------------------------- def create_dropout_mask(dropout_p, shape, seed): device = "cuda" - rand_vals = torch.rand(shape, generator=torch.Generator(device=device).manual_seed(seed), device=device, dtype=torch.float32) + rand_vals = torch.rand( + shape, + generator=torch.Generator(device=device).manual_seed(seed), + device=device, + dtype=torch.float32, + ) return rand_vals > dropout_p -def create_dropout_mask_varlen(dropout_p, batch, nheads_q, cu_seqlens_q, cu_seqlens_k, philox_seed): + +def create_dropout_mask_varlen( + dropout_p, batch, nheads_q, cu_seqlens_q, cu_seqlens_k, philox_seed +): device = "cuda" - qlens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]) - klens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]) + qlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + klens = cu_seqlens_k[1:] - cu_seqlens_k[:-1] max_qlen = qlens.max() max_klen = klens.max() dropout_mask = torch.zeros((batch, nheads_q, max_qlen, max_klen), device=device) for b in range(batch): qlen = qlens[b] klen = klens[b] - rand_vals = torch.rand((nheads_q, qlen, klen), generator=torch.Generator(device=device).manual_seed(philox_seed), device=device, dtype=torch.float32) + rand_vals = torch.rand( + (nheads_q, qlen, klen), + generator=torch.Generator(device=device).manual_seed(philox_seed), + device=device, + dtype=torch.float32, + ) submask = rand_vals > dropout_p dropout_mask[b, :, :qlen, :klen] = submask return dropout_mask -def write_dropout_mask(x, tensor_name = "tensor"): + +def write_dropout_mask(x, tensor_name="tensor"): batch, head, seqlen_m, seqlen_n = x.shape x = x.tolist() - with open(f'{tensor_name}.csv', 'w') as f: + with open(f"{tensor_name}.csv", "w") as f: writer = csv.writer(f) for b in range(batch): for h in range(head): @@ -1041,22 +1328,22 @@ def write_dropout_mask(x, tensor_name = "tensor"): if True: BLOCK_M = 64 BLOCK_N = 64 - + # Calculate number of blocks in each dimension m_blocks = math.ceil(seqlen_m / BLOCK_M) n_blocks = math.ceil(seqlen_n / BLOCK_N) - + # Process each block for m_block in range(m_blocks): # Calculate row range for current block row_start = m_block * BLOCK_M row_end = min(row_start + BLOCK_M, seqlen_m) - + for n_block in range(n_blocks): # Calculate column range for current block col_start = n_block * BLOCK_N col_end = min(col_start + BLOCK_N, seqlen_n) - + # Extract and write the current block for row_idx in range(row_start, row_end): row_data = dropout_mask[row_idx][col_start:col_end] @@ -1064,6 +1351,379 @@ def write_dropout_mask(x, tensor_name = "tensor"): else: writer.writerows(dropout_mask) + +# ------------------------------- +# Rotary +# ------------------------------- +@triton.jit +def _rotary_kernel( + OUT, + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, + seqlen, + nheads, + seqlen_ro, + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + ROTARY_DIM: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_M: tl.constexpr, +): + BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM) + ROTARY_DIM_HALF = ROTARY_DIM // 2 + pid_head = tl.program_id(axis=0) + pid_m = tl.program_id(axis=1) + pid_batch = tl.program_id(axis=2) + + if not IS_VARLEN: + X = X + pid_batch * stride_x_batch + OUT = OUT + pid_batch * stride_out_batch + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X = X + start_idx * stride_x_seqlen + OUT = OUT + start_idx * stride_out_seqlen + + if pid_m * BLOCK_M >= seqlen: + return + + rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + + rk_half = tl.arange(0, BLOCK_K // 2) + COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) + SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) + mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF) + cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32) + sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + + if not INTERLEAVED: + X = X + ( + rh[:, None, None] * stride_x_nheads + + rm[None, :, None] * stride_x_seqlen + + rk_half[None, None, :] * stride_x_headdim + ) + OUT = OUT + ( + rh[:, None, None] * stride_out_nheads + + rm[None, :, None] * stride_out_seqlen + + rk_half[None, None, :] * stride_out_headdim + ) + mask = ( + (rh[:, None, None] < nheads) + & (rm[None, :, None] < seqlen) + & (rk_half[None, None, :] < ROTARY_DIM_HALF) + ) + x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0).to( + tl.float32 + ) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + tl.store(OUT, o0, mask=mask) + tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask) + else: + rk = tl.arange(0, BLOCK_K) + X = X + ( + rh[:, None, None] * stride_x_nheads + + rm[None, :, None] * stride_x_seqlen + + rk[None, None, :] * stride_x_headdim + ) + OUT = OUT + ( + rh[:, None, None] * stride_out_nheads + + rm[None, :, None] * stride_out_seqlen + + rk[None, None, :] * stride_out_headdim + ) + mask = ( + (rh[:, None, None] < nheads) + & (rm[None, :, None] < seqlen) + & (rk[None, None, :] < ROTARY_DIM) + ) + x = tl.load(X, mask=mask, other=0.0).to(tl.float32) + x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2])) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K]) + tl.store(OUT, o, mask=mask) + + +def _apply_rotary_kernel( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert ( + max_seqlen is not None + ), "If cu_seqlens is passed, max_seqlen must also be provided" + total_seqlen, nheads, headdim = x.shape + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim_half = cos.shape + assert sin.shape == cos.shape + rotary_dim = 2 * rotary_dim_half + assert rotary_dim <= headdim + assert headdim <= 256 + assert seqlen_ro >= seqlen + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in (torch.int32, torch.int64) + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + out = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + out[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + # Block heuristics + BLOCK_M = 8 if rotary_dim <= 128 else 4 + grid = ( + triton.cdiv(nheads, 2), + triton.cdiv(seqlen, BLOCK_M), + batch, + ) + + # NOTE: We assume CUDA device indexing compatibility in upstream; adapt for ROCm by using device context. + # For ROCm, torch.cuda.device works if HIP_VISIBLE_DEVICES mapping is set. + with torch.cuda.device(x.device.index): # Works for ROCm as alias + torch.library.wrap_triton(_rotary_kernel)[grid]( + out, + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, + nheads, + seqlen_ro, + out.stride(0) if not is_varlen else 0, + out.stride(-3), + out.stride(-2), + out.stride(-1), + x.stride(0) if not is_varlen else 0, + x.stride(-3), + x.stride(-2), + x.stride(-1), + rotary_dim, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + BLOCK_M=BLOCK_M, + BLOCK_H=2, + ) + return out + + +class _ApplyRotary(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool, + inplace: bool, + seqlen_offsets: Union[int, torch.Tensor], + cu_seqlens: Optional[torch.Tensor], + max_seqlen: Optional[int], + ): + out = _apply_rotary_kernel( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, + conjugate=False, + ) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x + + @staticmethod + def backward(ctx, do: torch.Tensor): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + dx = _apply_rotary_kernel( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + ) + return dx, None, None, None, None, None, None, None + + +def apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool = False, + inplace: bool = False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +) -> torch.Tensor: + """Public API: apply rotary embeddings to tensor x. + + Args: + x: (B, S, H, D) if `cu_seqlens` is None else (total_S, H, D). + cos, sin: (S_rotary, rotary_dim/2) + interleaved: GPT-J style if True. + inplace: modify x in place (saves memory if rotary_dim == D). + seqlen_offsets: int or (B,) tensor of starting offsets per sequence (KV cache decode). + cu_seqlens: (B+1,) tensor enabling varlen mode. + max_seqlen: required when `cu_seqlens` is provided. + """ + # FP8 path: upcast to bfloat16 (preferred) or float16 for rotary math to avoid excessive error + original_dtype = x.dtype + is_fp8_input = original_dtype == getattr(torch, "float8_e4m3fn", None) + if is_fp8_input: + # Choose bf16 if available in cos.dtype path; otherwise fallback to float16 + target_dtype = ( + torch.bfloat16 + if cos.dtype == torch.bfloat16 or torch.cuda.is_bf16_supported() + else torch.float16 + ) + # Upcast x, cos, sin for computation (without modifying originals in-place) + x_up = x.to(target_dtype) + cos_up = cos.to(target_dtype) if cos.dtype != target_dtype else cos + sin_up = sin.to(target_dtype) if sin.dtype != target_dtype else sin + out_up = _ApplyRotary.apply( + x_up, + cos_up, + sin_up, + interleaved, + False, + seqlen_offsets, + cu_seqlens, + max_seqlen, + ) + # Cast result back to original fp8 dtype + if inplace: + x.copy_(out_up.to(original_dtype)) + return x + return out_up.to(original_dtype) + else: + return _ApplyRotary.apply( + x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen + ) + + +def apply_rotary( + q: torch.Tensor, + k_new: Optional[torch.Tensor], + cos: torch.Tensor, + sin: torch.Tensor, + *, + causal: bool, + local: bool, + interleaved: bool = False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """High-level rotary application used by AMD prefill & decode paths. + + Policy (matches test reference & legacy semantics): + - If causal OR local attention ⇒ apply rotary directly on (B, S, H, D). + - Else (non-causal global) ⇒ flatten heads into sequence: (B, 1, S*H, D), + apply rotary once, then unflatten back. + - k_new (incremental KV slice) is always rotated directly when provided. + + Args: + q: (B, S, H, D) + k_new: Optional (B, S_k, H_k, D) + cos, sin: rotary caches (S_rotary, rotary_dim/2) + causal: causal attention flag + local: sliding-window / local attention flag (pre-computed outside) + interleaved: GPT-J style rotary layout + seqlen_offsets: int or (B,) tensor of per-sequence start offsets + Returns: + (q_rot, k_new_rot) + """ + assert q.ndim == 4, f"Expected q shape (B,S,H,D), got {q.shape}" + B, S, H, D = q.shape + use_flatten = (not causal) and (not local) + + if use_flatten: + # Flatten (S,H) -> (S*H) with an added singleton dim to preserve expected 4D shape. + q_flat = q.reshape(B, S * H, D).unsqueeze(1) # (B, 1, S*H, D) + q_flat = apply_rotary_emb( + q_flat, + cos, + sin, + interleaved=interleaved, + seqlen_offsets=seqlen_offsets, + ) + # Restore shape back to (B, S, H, D) + q = q_flat.view(B, 1, S * H, D).reshape(B, S, H, D) + else: + q = apply_rotary_emb( + q, + cos, + sin, + interleaved=interleaved, + seqlen_offsets=seqlen_offsets, + ) + + if k_new is not None: + k_new = apply_rotary_emb( + k_new, + cos, + sin, + interleaved=interleaved, + seqlen_offsets=seqlen_offsets, + ) + return q, k_new + + # ------------------------------- # Runtime info # ------------------------------- @@ -1071,18 +1731,36 @@ def write_dropout_mask(x, tensor_name = "tensor"): def is_hip(): return triton.runtime.driver.active.get_current_target().backend == "hip" + @functools.cache def get_arch(): return triton.runtime.driver.active.get_current_target().arch + @functools.cache def is_cdna(): - return is_hip() and get_arch() in ('gfx908', 'gfx90a', 'gfx940', 'gfx941', 'gfx942', 'gfx950') + return is_hip() and get_arch() in ( + "gfx908", + "gfx90a", + "gfx940", + "gfx941", + "gfx942", + "gfx950", + ) + @functools.cache def is_rdna(): - return is_hip() and get_arch() in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201") + return is_hip() and get_arch() in ( + "gfx1030", + "gfx1100", + "gfx1101", + "gfx1102", + "gfx1200", + "gfx1201", + ) + @functools.cache def arch_supports_fp8(): - return is_hip() and get_arch() in ('gfx942') + return is_hip() and get_arch() in ("gfx942") diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 8ddfa37bd31..4f54e44b95c 100755 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -11,7 +11,7 @@ if USE_TRITON_ROCM: import sys sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - from flash_attn.flash_attn_triton_amd import interface_fa_v3 as flash_attn_3_gpu + from flash_attn.flash_attn_triton_amd import flash_attn_3 as flash_attn_3_gpu else: # isort: off # We need to import the CUDA kernels after importing torch diff --git a/hopper/test_flash_attn_triton_amd.py b/hopper/test_flash_attn_triton_amd.py index 738ec1d8c13..73e54dce066 100755 --- a/hopper/test_flash_attn_triton_amd.py +++ b/hopper/test_flash_attn_triton_amd.py @@ -27,7 +27,7 @@ DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "TRUE") == "TRUE" DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" -DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "TRUE") == "TRUE" +DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "TRUE") == "TRUE" DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "TRUE") == "TRUE" DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "TRUE") == "TRUE" @@ -579,9 +579,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) +@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) # @pytest.mark.parametrize("new_kv", [True]) -@pytest.mark.parametrize("new_kv", [False, True]) @pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) # @pytest.mark.parametrize("causal,local", [(False, False)]) From 9f807ac46c093ccef912e69994dcf6e890789d97 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Tue, 30 Sep 2025 10:59:24 -0400 Subject: [PATCH 14/27] add bwd_change (#156) --- .gitignore | 1 + flash_attn/flash_attn_triton_amd/bwd.py | 379 ++++++++---------------- 2 files changed, 121 insertions(+), 259 deletions(-) diff --git a/.gitignore b/.gitignore index 96d9af07e82..c4a123e4e0b 100644 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,7 @@ scripts csrc/flash_attn_ck .eggs log +*.rocprof* *.log core.* gpucore.* diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 085232cedc5..4d4c22866d6 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -19,286 +19,147 @@ tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) -def get_autotune_configs(): - if False: - if is_cdna(): - # shared meta-parameters - NUM_STAGES = 1 - NUM_WARPS = 4 - WAVES_PER_EU = 2 - MATRIX_INSTR_NONKDIM = 16 - - preprocess_autotune_configs = [ - triton.Config( - { - "PRE_BLOCK": 128, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), # og config - triton.Config( - { - "PRE_BLOCK": 64, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - triton.Config( - { - "PRE_BLOCK": 32, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - triton.Config( - { - "PRE_BLOCK": 16, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - ] - preprocess_autotune_keys = [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM_QK", - "ACTUAL_HEAD_DIM_V", - "IS_VARLEN", - "HQ", - "HK", - ] - causal_autotune_configs = [ - triton.Config( - { - "BLOCK_M1": 32, - "BLOCK_N1": 128, - "BLOCK_M2": 128, - "BLOCK_N2": 32, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), # og config - triton.Config( - { - "BLOCK_M1": 16, - "BLOCK_N1": 128, - "BLOCK_M2": 128, - "BLOCK_N2": 16, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - triton.Config( - { - "BLOCK_M1": 16, - "BLOCK_N1": 64, - "BLOCK_M2": 64, - "BLOCK_N2": 16, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - triton.Config( - { - "BLOCK_M1": 32, - "BLOCK_N1": 64, - "BLOCK_M2": 64, - "BLOCK_N2": 32, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - ] - causal_autotune_keys = [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM_QK", - "ACTUAL_HEAD_DIM_V", - "IS_VARLEN", - "HQ", - "HK", - ] - noncausal_autotune_configs = [ - triton.Config( - { - "BLOCK_M1": 32, - "BLOCK_N1": 128, - "BLOCK_M2": 128, - "BLOCK_N2": 32, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), # og config - triton.Config( - { - "BLOCK_M1": 16, - "BLOCK_N1": 128, - "BLOCK_M2": 128, - "BLOCK_N2": 16, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - triton.Config( - { - "BLOCK_M1": 16, - "BLOCK_N1": 64, - "BLOCK_M2": 64, - "BLOCK_N2": 16, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - triton.Config( - { - "BLOCK_M1": 32, - "BLOCK_N1": 64, - "BLOCK_M2": 64, - "BLOCK_N2": 32, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - ] - noncausal_autotune_keys = [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM_QK", - "ACTUAL_HEAD_DIM_V", - "IS_VARLEN", - "HQ", - "HK", - ] - - return ( - (preprocess_autotune_configs, preprocess_autotune_keys), - (causal_autotune_configs, causal_autotune_keys), - (noncausal_autotune_configs, noncausal_autotune_keys), - ) - else: - raise ValueError("Unknown Device Type") - else: - # meta-parameters - # TODO: fix num_stages later - NUM_WARPS, NUM_STAGES = 4, 1 - WAVES_PER_EU = 1 - PRE_BLOCK = 128 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - BLK_SLICE_FACTOR = 2 - - assert BLOCK_N1 == BLOCK_M2 +def get_bwd_configs(autotune = False): + # default config + if not autotune: + # preprocess params + PRE_BLOCK = 64 + PRE_WAVES_PER_EU=2 + PRE_NUM_STAGES=2 + PRE_NUM_WARPS=8 # configs for the kernels preprocess_autotune_configs = [ - triton.Config( - {"PRE_BLOCK": PRE_BLOCK, "waves_per_eu": WAVES_PER_EU}, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), + triton.Config({"PRE_BLOCK": PRE_BLOCK, "waves_per_eu": PRE_WAVES_PER_EU}, num_stages=PRE_NUM_STAGES, num_warps=PRE_NUM_WARPS), ] preprocess_autotune_keys = [ "max_seqlen_q", - "ACTUAL_HEAD_DIM_V", - "IS_VARLEN", + "ACTUAL_HEAD_DIM", "IS_VARLEN", ] + + # main params + NUM_STAGES=1 + NUM_WARPS= 4 + WAVES_PER_EU = 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 64 + BLK_SLICE_FACTOR = 2 + MATRIX_INSTR_NONKDIM=16 + assert BLOCK_N1 == BLOCK_M2 + causal_autotune_configs = [ - triton.Config( - { - "BLOCK_M1": BLOCK_M1, - "BLOCK_N1": BLOCK_N1, - "BLOCK_M2": BLOCK_M2, - "BLOCK_N2": BLOCK_N2, - "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, - "waves_per_eu": WAVES_PER_EU, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), + triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), ] causal_autotune_keys = [ - "dropout_p", - "max_seqlen_q", - "max_seqlen_k", - "ACTUAL_HEAD_DIM_QK", - "ACTUAL_HEAD_DIM_V", - "IS_VARLEN", - "HQ", - "HK", + "dropout_p", "max_seqlen_q", "max_seqlen_k", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", ] noncausal_autotune_configs = [ - triton.Config( - { - "BLOCK_M1": BLOCK_M1, - "BLOCK_N1": BLOCK_N1, - "BLOCK_M2": BLOCK_M2, - "BLOCK_N2": BLOCK_N2, - "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, - "waves_per_eu": WAVES_PER_EU, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), + triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), ] noncausal_autotune_keys = [ - "dropout_p", - "max_seqlen_q", - "max_seqlen_k", - "ACTUAL_HEAD_DIM_QK", - "ACTUAL_HEAD_DIM_V", - "IS_VARLEN", - "HQ", - "HK", + "dropout_p", "max_seqlen_q", "max_seqlen_k", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", ] - return ( - (preprocess_autotune_configs, preprocess_autotune_keys), - (causal_autotune_configs, causal_autotune_keys), - (noncausal_autotune_configs, noncausal_autotune_keys), - ) + return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) + + + # params + PRE_BLOCK_OPTIONS = [64, 128] # og: 128 + PRE_WAVES_PER_EU_OPTIONS=[1, 2] + PRE_NUM_STAGES_OPTIONS=[1, 2] + PRE_NUM_WARPS_OPTIONS=[4, 8] + + + # Preprocess configs + preprocess_autotune_configs = [] + for pre_num_warps in PRE_NUM_WARPS_OPTIONS: + for pre_num_stages in PRE_NUM_STAGES_OPTIONS: + for pre_waves in PRE_WAVES_PER_EU_OPTIONS: + for pre_block in PRE_BLOCK_OPTIONS: + preprocess_autotune_configs.append( + triton.Config({ + "PRE_BLOCK": pre_block, + "waves_per_eu": pre_waves, + }, num_stages=pre_num_stages, num_warps=pre_num_warps) + ) + + NUM_STAGES_OPTIONS = [1, 2] # og: 1 + NUM_WARPS_OPTIONS = [4, 8] # og: 4 + WAVES_PER_EU_OPTIONS = [1, 2] # og: 1 + MATRIX_INSTR_NONKDIM_OPTIONS = [16, 32] # og: 16 + BLOCK_M1_OPTIONS = [ # og: 32 + 32, 64 + ] + BLOCK_N1_M2_OPTIONS = [ # og: 128 + 64, 128 + ] + BLOCK_N2_OPTIONS = [ # og: 32 + 32, 64 + ] + BLK_SLICE_FACTOR_OPTIONS = [2] # og: 2 + + # build configs + causal_autotune_configs = [] + noncausal_autotune_configs = [] + for num_warps in NUM_WARPS_OPTIONS: + for num_stages in NUM_STAGES_OPTIONS: + for waves in WAVES_PER_EU_OPTIONS: + for matrix_instr_nonkdim in MATRIX_INSTR_NONKDIM_OPTIONS: + # Causal and non-causal configs + for m1 in BLOCK_M1_OPTIONS: + for n1 in BLOCK_N1_M2_OPTIONS: + m2 = n1 + for n2 in BLOCK_N2_OPTIONS: + # Ensure constraint + assert n1 == m2, f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" + + for blk_slice in BLK_SLICE_FACTOR_OPTIONS: + causal_autotune_configs.append( + triton.Config({ + "BLOCK_M1": m1, "BLOCK_N1": n1, + "BLOCK_M2": m2, "BLOCK_N2": n2, + "BLK_SLICE_FACTOR": blk_slice, + "waves_per_eu": waves, + "matrix_instr_nonkdim": matrix_instr_nonkdim + }, num_stages=num_stages, num_warps=num_warps) + ) + + noncausal_autotune_configs.append( + triton.Config({ + "BLOCK_M1": m1, "BLOCK_N1": n1, + "BLOCK_M2": m2, "BLOCK_N2": n2, + "BLK_SLICE_FACTOR": blk_slice, + "waves_per_eu": waves, + "matrix_instr_nonkdim": matrix_instr_nonkdim + }, num_stages=num_stages, num_warps=num_warps) + ) + + # kernel keys + preprocess_autotune_keys = [ + "max_seqlen_q", + "ACTUAL_HEAD_DIM", "IS_VARLEN", + ] + + causal_autotune_keys = [ + "dropout_p", "max_seqlen_q", "max_seqlen_k", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + + noncausal_autotune_keys = [ + "dropout_p", "max_seqlen_q", "max_seqlen_k", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + + return (preprocess_autotune_configs, preprocess_autotune_keys), \ + (causal_autotune_configs, causal_autotune_keys), \ + (noncausal_autotune_configs, noncausal_autotune_keys) ( (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys), -) = get_autotune_configs() +) = get_bwd_configs() # This function computes delta given output Out and gradient DO From 3499ee97f2736fdf3f22ff4055b3f95ecd814811 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Mon, 13 Oct 2025 07:10:21 -0400 Subject: [PATCH 15/27] Tune FP8 Perf (#160) * check cu count for gfx942 * create get_cu_count * update repo root * update forward tune * clean up load * use float8_e4m3fnuz * save * show bwd mode * recommend fp8 * use torch.float32 for fp8 kernel * add both best fp16 and fp8 config * tune fp8 backward * descale factors should be b, hk * fp8 bwd working on all primus configs * tune bwd configs * fa v3 tests passing * better warning * clean up bwd launcher * v3 passing * tune more * improve perf * clean up * lint * clean * start tuning gfx950 * tune non causal path * fix bug * save * Skip configs where BLOCK_M2 % BLOCK_N2 != 0 * skip more * stop tuning * fix varlen bug * fix dropout & causal/swa segfault --- flash_attn/flash_attn_triton_amd/bwd.py | 2278 ++++++++--------- .../flash_attn_triton_amd/fwd_decode.py | 69 +- .../flash_attn_triton_amd/fwd_prefill.py | 650 ++--- .../flash_attn_triton_amd/interface_v2.py | 253 +- .../flash_attn_triton_amd/interface_v3.py | 300 ++- flash_attn/flash_attn_triton_amd/utils.py | 397 +-- hopper/flash_attn_interface.py | 9 +- 7 files changed, 2030 insertions(+), 1926 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 4d4c22866d6..d2ed7aa113a 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -2,252 +2,504 @@ import torch import triton # type: ignore import triton.language as tl # type: ignore +import warnings from typing import Literal, Optional from .utils import ( DEBUG, - DROPOUT_USE_PYTORCH, - DROPOUT_DUMP, + AUTOTUNE, compute_fp8_scaling_factors, - create_dropout_mask, - create_dropout_mask_varlen, + get_cu_count, is_cdna, is_fp8, + get_arch, ) -# NOTE: triton fails to import tl.constexprs so create them here for the file -tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) -tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) +def get_bwd_configs(autotune: bool): + # keys + preprocess_autotune_keys = [ + "max_seqlen_q", + "ACTUAL_HEAD_DIM", + "IS_VARLEN", + ] + + causal_autotune_keys = [ + "dropout_p", + "max_seqlen_q", + "max_seqlen_k", + "ACTUAL_HEAD_DIM", + "IS_VARLEN", + "HQ", + "HK", + ] + noncausal_autotune_keys = [ + "dropout_p", + "max_seqlen_q", + "max_seqlen_k", + "ACTUAL_HEAD_DIM", + "IS_VARLEN", + "HQ", + "HK", + ] -def get_bwd_configs(autotune = False): # default config if not autotune: - # preprocess params - PRE_BLOCK = 64 - PRE_WAVES_PER_EU=2 - PRE_NUM_STAGES=2 - PRE_NUM_WARPS=8 - + arch = get_arch() # configs for the kernels - preprocess_autotune_configs = [ - triton.Config({"PRE_BLOCK": PRE_BLOCK, "waves_per_eu": PRE_WAVES_PER_EU}, num_stages=PRE_NUM_STAGES, num_warps=PRE_NUM_WARPS), - ] - preprocess_autotune_keys = [ - "max_seqlen_q", - "ACTUAL_HEAD_DIM", "IS_VARLEN", - ] - - # main params - NUM_STAGES=1 - NUM_WARPS= 4 - WAVES_PER_EU = 1 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 64 - BLK_SLICE_FACTOR = 2 - MATRIX_INSTR_NONKDIM=16 - assert BLOCK_N1 == BLOCK_M2 - - causal_autotune_configs = [ - triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - ] - causal_autotune_keys = [ - "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", - ] - noncausal_autotune_configs = [ - triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - ] - noncausal_autotune_keys = [ - "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", - ] - return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) - - - # params - PRE_BLOCK_OPTIONS = [64, 128] # og: 128 - PRE_WAVES_PER_EU_OPTIONS=[1, 2] - PRE_NUM_STAGES_OPTIONS=[1, 2] - PRE_NUM_WARPS_OPTIONS=[4, 8] - - - # Preprocess configs + if arch == "gfx942": + if get_cu_count() < 304: + preprocess_autotune_configs = [ + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 1}, num_stages=1, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 128, "waves_per_eu": 2}, num_stages=1, num_warps=4 + ), + ] + noncausal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=8, + ), + ] + causal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + ] + else: + preprocess_autotune_configs = [ + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 1}, num_stages=1, num_warps=4 + ), + ] + noncausal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + ] + causal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + ] + elif arch == "gfx950": + preprocess_autotune_configs = [ + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=1, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=4 + ), + ] + noncausal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 128, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 16, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 2, + }, + num_stages=1, + num_warps=4, + ), + ] + causal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + ] + else: + preprocess_autotune_configs = [ + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 + ), + ] + noncausal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + ] + causal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + ] + + # assert constraints + for noncausal_cfg, causal_cfg in zip( + noncausal_autotune_configs, causal_autotune_configs + ): + assert ( + noncausal_cfg.all_kwargs()["BLOCK_N1"] + == noncausal_cfg.all_kwargs()["BLOCK_M2"] + ), f"BLOCK_N1 ({noncausal_cfg.all_kwargs()['BLOCK_N1']}) must equal BLOCK_M2 ({noncausal_cfg.all_kwargs()['BLOCK_M2']})" + assert ( + causal_cfg.all_kwargs()["BLOCK_N1"] + == causal_cfg.all_kwargs()["BLOCK_M2"] + ), f"BLOCK_N1 ({causal_cfg.all_kwargs()['BLOCK_N1']}) must equal BLOCK_M2 ({causal_cfg.all_kwargs()['BLOCK_M2']})" + + return ( + (preprocess_autotune_configs, preprocess_autotune_keys), + (causal_autotune_configs, causal_autotune_keys), + (noncausal_autotune_configs, noncausal_autotune_keys), + ) + + # param options + PRE_BLOCK_OPTIONS = [64, 128] # og: 128 + PRE_WAVES_PER_EU_OPTIONS = [1, 2] + PRE_NUM_STAGES_OPTIONS = [1, 2] + PRE_NUM_WARPS_OPTIONS = [4, 8] + NUM_STAGES_OPTIONS = [1, 2] # og: 1 + NUM_WARPS_OPTIONS = [4, 8] # og: 4 + WAVES_PER_EU_OPTIONS = [1, 2] # og: 1 + NON_CAUSAL_BLOCK_M1_OPTIONS = [16, 32, 64, 128] # og: 32 + NON_CAUSAL_BLOCK_N1_M2_OPTIONS = [32, 64, 128, 256] # og: 128 + NON_CAUSAL_BLOCK_N2_OPTIONS = [16, 32, 64, 128] # og: 32 + CAUSAL_BLOCK_M1_OPTIONS = [ # og: 32 + 32, + 64 + ] + CAUSAL_BLOCK_N1_M2_OPTIONS = [32, 64, 128] # og: 128 + CAUSAL_BLOCK_N2_OPTIONS = [32, 64] # og: 32 + BLK_SLICE_FACTOR_OPTIONS = [2] # og: 2 + + # ==================== sweep configs ================================ preprocess_autotune_configs = [] for pre_num_warps in PRE_NUM_WARPS_OPTIONS: for pre_num_stages in PRE_NUM_STAGES_OPTIONS: for pre_waves in PRE_WAVES_PER_EU_OPTIONS: for pre_block in PRE_BLOCK_OPTIONS: preprocess_autotune_configs.append( - triton.Config({ - "PRE_BLOCK": pre_block, - "waves_per_eu": pre_waves, - }, num_stages=pre_num_stages, num_warps=pre_num_warps) + triton.Config( + { + "PRE_BLOCK": pre_block, + "waves_per_eu": pre_waves, + }, + num_stages=pre_num_stages, + num_warps=pre_num_warps, + ) ) - NUM_STAGES_OPTIONS = [1, 2] # og: 1 - NUM_WARPS_OPTIONS = [4, 8] # og: 4 - WAVES_PER_EU_OPTIONS = [1, 2] # og: 1 - MATRIX_INSTR_NONKDIM_OPTIONS = [16, 32] # og: 16 - BLOCK_M1_OPTIONS = [ # og: 32 - 32, 64 - ] - BLOCK_N1_M2_OPTIONS = [ # og: 128 - 64, 128 - ] - BLOCK_N2_OPTIONS = [ # og: 32 - 32, 64 - ] - BLK_SLICE_FACTOR_OPTIONS = [2] # og: 2 - - # build configs causal_autotune_configs = [] - noncausal_autotune_configs = [] for num_warps in NUM_WARPS_OPTIONS: for num_stages in NUM_STAGES_OPTIONS: for waves in WAVES_PER_EU_OPTIONS: - for matrix_instr_nonkdim in MATRIX_INSTR_NONKDIM_OPTIONS: - # Causal and non-causal configs - for m1 in BLOCK_M1_OPTIONS: - for n1 in BLOCK_N1_M2_OPTIONS: + for m1 in CAUSAL_BLOCK_M1_OPTIONS: + for n1 in CAUSAL_BLOCK_N1_M2_OPTIONS: m2 = n1 - for n2 in BLOCK_N2_OPTIONS: + for n2 in CAUSAL_BLOCK_N2_OPTIONS: # Ensure constraint - assert n1 == m2, f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" + assert ( + n1 == m2 + ), f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" + + # Skip configs where BLOCK_M2 % BLOCK_N2 != 0 + if m2 % n2 != 0: + continue + # Skip configs where BLOCK_N1 % BLOCK_M1 != 0 + if n1 % m1 != 0: + continue + for blk_slice in BLK_SLICE_FACTOR_OPTIONS: causal_autotune_configs.append( - triton.Config({ - "BLOCK_M1": m1, "BLOCK_N1": n1, - "BLOCK_M2": m2, "BLOCK_N2": n2, - "BLK_SLICE_FACTOR": blk_slice, - "waves_per_eu": waves, - "matrix_instr_nonkdim": matrix_instr_nonkdim - }, num_stages=num_stages, num_warps=num_warps) + triton.Config( + { + "BLOCK_M1": m1, + "BLOCK_N1": n1, + "BLOCK_M2": m2, + "BLOCK_N2": n2, + "BLK_SLICE_FACTOR": blk_slice, + "waves_per_eu": waves, + }, + num_stages=num_stages, + num_warps=num_warps, + ) ) + noncausal_autotune_configs = [] + for num_warps in NUM_WARPS_OPTIONS: + for num_stages in NUM_STAGES_OPTIONS: + for waves in WAVES_PER_EU_OPTIONS: + for m1 in NON_CAUSAL_BLOCK_M1_OPTIONS: + for n1 in NON_CAUSAL_BLOCK_N1_M2_OPTIONS: + m2 = n1 + for n2 in NON_CAUSAL_BLOCK_N2_OPTIONS: + # Ensure constraint + assert ( + n1 == m2 + ), f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" + + # Skip configs where BLOCK_M2 % BLOCK_N2 != 0 + if m2 % n2 != 0: + continue + + # Skip configs where BLOCK_N1 % BLOCK_M1 != 0 + if n1 % m1 != 0: + continue + + for blk_slice in BLK_SLICE_FACTOR_OPTIONS: noncausal_autotune_configs.append( - triton.Config({ - "BLOCK_M1": m1, "BLOCK_N1": n1, - "BLOCK_M2": m2, "BLOCK_N2": n2, - "BLK_SLICE_FACTOR": blk_slice, - "waves_per_eu": waves, - "matrix_instr_nonkdim": matrix_instr_nonkdim - }, num_stages=num_stages, num_warps=num_warps) + triton.Config( + { + "BLOCK_M1": m1, + "BLOCK_N1": n1, + "BLOCK_M2": m2, + "BLOCK_N2": n2, + "BLK_SLICE_FACTOR": blk_slice, + "waves_per_eu": waves, + }, + num_stages=num_stages, + num_warps=num_warps, + ) ) - - # kernel keys - preprocess_autotune_keys = [ - "max_seqlen_q", - "ACTUAL_HEAD_DIM", "IS_VARLEN", - ] - - causal_autotune_keys = [ - "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", - ] - - noncausal_autotune_keys = [ - "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", - ] - - return (preprocess_autotune_configs, preprocess_autotune_keys), \ - (causal_autotune_configs, causal_autotune_keys), \ - (noncausal_autotune_configs, noncausal_autotune_keys) + return ( + (preprocess_autotune_configs, preprocess_autotune_keys), + (causal_autotune_configs, causal_autotune_keys), + (noncausal_autotune_configs, noncausal_autotune_keys), + ) +# os.environ["TRITON_PRINT_AUTOTUNING"] = "1" ( (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys), -) = get_bwd_configs() - - -# This function computes delta given output Out and gradient DO -# Here is the I/O shape: -# Out: (batch, nhead_q, max_seqlens_q, headDim) -# DO: (batch, nhead_q, max_seqlens_q, headDim) -# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at -@triton.jit -def _bwd_fused_atomics_preprocess( - o_ptr, - do_ptr, # noqa: E741 - delta_ptr, - stride_o_b, - stride_o_h, - stride_o_m, - stride_o_k, - stride_delta_b, - stride_delta_h, - stride_delta_m, - stride_descale_do_z, - cu_seqlens_q, - max_seqlen_q, - descale_do_ptr, - BLOCK_M: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, -): - pid_m = tl.program_id(0) # seqlen - bid = tl.program_id(1) # batch - hid = tl.program_id(2) # head - - # Handle varlen - q_start = 0 - seqlen_q = max_seqlen_q - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - seqlen_q = q_end - q_start - else: - q_start = 0 - seqlen_q = max_seqlen_q - - # Compute offsets - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - - # Offset O/DO by batch, head and q_start - offs = ( - bid * stride_o_b - + hid * stride_o_h - + q_start * stride_o_m - + offs_m[:, None] * stride_o_m - + offs_k[None, :] * stride_o_k - ) - - # create masks - mask_m = offs_m < seqlen_q - mask = mask_m[:, None] - PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 - if PADDED_HEAD: - mask &= offs_k[None, :] < BLOCK_D_MODEL - - # load [BLOCK_M, BLOCK_D_MODEL_POW2] - o = tl.load(o_ptr + offs, mask=mask, other=0.0) - do = tl.load(do_ptr + offs, mask=mask, other=0.0) - - # compute and write-back to delta - if IS_FP8: - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hid) - - # NOTE: do is in the fp8 range and o is not in fp8 - delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) - else: - delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) - - offs_delta = ( - bid * stride_delta_b - + hid * stride_delta_h - + q_start * stride_delta_m - + offs_m * stride_delta_m - ) - tl.store(delta_ptr + offs_delta, delta, mask=mask_m) +) = get_bwd_configs(AUTOTUNE) @triton.jit -def _bwd_fused_atomics_dq_inner( +def _bwd_dq_inner_split( dq, q, K, @@ -278,7 +530,6 @@ def _bwd_fused_atomics_dq_inner( descale_q, descale_k, descale_v, - descale_do, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D_MODEL: tl.constexpr, @@ -347,7 +598,7 @@ def _bwd_fused_atomics_dq_inner( # dp if IS_FP8: - dp = tl.dot(do, vT) * descale_do * descale_v + dp = tl.dot(do.to(vT.type.element_ty), vT) * descale_v else: dp = tl.dot(do, vT) @@ -361,12 +612,11 @@ def _bwd_fused_atomics_dq_inner( # dq # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. if IS_FP8: - scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) - dq += ( - tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) - * descale_ds - * descale_k - ) + # Rewrite dq += ds @ kT.T as dq += (kT @ ds.T).T + # This puts FP8 tensor (kT) on LHS of dot product + # Cast the transposed ds to FP8 to match kT's dtype + ds_transposed = tl.trans(ds).to(kT.type.element_ty) + dq += tl.trans(tl.dot(kT, ds_transposed)) * descale_k else: dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) @@ -377,7 +627,7 @@ def _bwd_fused_atomics_dq_inner( @triton.jit -def _bwd_fused_atomics_dkdv_inner( +def _bwd_dkdv_inner_split( dk, dv, Q, @@ -406,7 +656,6 @@ def _bwd_fused_atomics_dkdv_inner( descale_q, descale_k, descale_v, - descale_do, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D_MODEL: tl.constexpr, @@ -497,34 +746,16 @@ def _bwd_fused_atomics_dkdv_inner( # dV if ENABLE_DROPOUT: pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors( - pT_dropout, FP8_MAX - ) - dv += ( - tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) - * descale_p_dropout - * descale_do - ) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) else: - if IS_FP8: - scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) - dv += ( - tl.dot((pT * scale_pT).to(do.type.element_ty), do) - * descale_pT - * descale_do - ) - else: - dv += tl.dot(pT.to(do.type.element_ty), do) + dv += tl.dot(pT.to(do.type.element_ty), do) # Load delta Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) # Compute dP and dS if IS_FP8: - dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do + dpT = tl.dot(v, tl.trans(do.to(v.type.element_ty))) * descale_v else: dpT = tl.dot(v, tl.trans(do)) @@ -536,12 +767,11 @@ def _bwd_fused_atomics_dkdv_inner( # compute dk if IS_FP8: - scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += ( - tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) - * descale_dsT - * descale_q - ) + # Rewrite dk += dsT @ qT.T as dk += (qT @ dsT.T).T + # This puts FP8 tensor (qT) on LHS of dot product + # Cast the transposed dsT to FP8 to match qT's dtype + dsT_transposed = tl.trans(dsT).to(qT.type.element_ty) + dk += tl.trans(tl.dot(qT, dsT_transposed)) * descale_q else: dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) @@ -554,7 +784,7 @@ def _bwd_fused_atomics_dkdv_inner( @triton.jit -def _bwd_fused_atomics_dkdvdq_inner( +def _bwd_dkdvdq_inner_atomic( dk, dv, Q, @@ -586,7 +816,6 @@ def _bwd_fused_atomics_dkdvdq_inner( descale_q, descale_k, descale_v, - descale_do, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D_MODEL: tl.constexpr, @@ -701,34 +930,16 @@ def _bwd_fused_atomics_dkdvdq_inner( # dV if ENABLE_DROPOUT: pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors( - pT_dropout, FP8_MAX - ) - dv += ( - tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) - * descale_p_dropout - * descale_do - ) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) else: - if IS_FP8: - scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) - dv += ( - tl.dot((pT * scale_pT).to(do.type.element_ty), do) - * descale_pT - * descale_do - ) - else: - dv += tl.dot(pT.to(do.type.element_ty), do) + dv += tl.dot(pT.to(do.type.element_ty), do) # Load delta Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) # Compute dP and dS if IS_FP8: - dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do + dpT = tl.dot(v, tl.trans(do.to(v.type.element_ty))) * descale_v else: dpT = tl.dot(v, tl.trans(do)) @@ -740,12 +951,11 @@ def _bwd_fused_atomics_dkdvdq_inner( # compute dk if IS_FP8: - scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += ( - tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) - * descale_dsT - * descale_q - ) + # Rewrite dk += dsT @ qT.T as dk += (qT @ dsT.T).T + # This puts FP8 tensor (qT) on LHS of dot product + # Cast the transposed dsT to FP8 to match qT's dtype + dsT_transposed = tl.trans(dsT).to(qT.type.element_ty) + dk += tl.trans(tl.dot(qT, dsT_transposed)) * descale_q else: dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) @@ -753,11 +963,9 @@ def _bwd_fused_atomics_dkdvdq_inner( # NOTE: Possible problems with the atomic add: contention, is inside a loop which has achieved bad perf before # (BLOCK_M, BLOCK_N) x (BLOCK_N, D) if IS_FP8: - dq_partial = ( - tl.dot((dsT * scale_dsT).to(k.dtype).T, k) * descale_dsT * descale_k - ) + dq_partial = tl.dot(dsT.to(k.type.element_ty).T, k) * descale_k else: - dq_partial = tl.dot(dsT.to(k.dtype).T, k) + dq_partial = tl.dot(dsT.to(k.type.element_ty).T, k) tl.atomic_add( dq_ptrs, dq_partial * sm_scale, @@ -769,7 +977,7 @@ def _bwd_fused_atomics_dkdvdq_inner( @triton.jit -def _bwd_kernel_fused_atomics_dkdvdq_causal( +def _bwd_kernel_fused_atomic_causal( q_ptr, k_ptr, v_ptr, @@ -814,7 +1022,6 @@ def _bwd_kernel_fused_atomics_dkdvdq_causal( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -826,7 +1033,6 @@ def _bwd_kernel_fused_atomics_dkdvdq_causal( descale_q_ptr, descale_k_ptr, descale_v_ptr, - descale_do_ptr, NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr, BATCH, @@ -989,15 +1195,12 @@ def _bwd_kernel_fused_atomics_dkdvdq_causal( descale_v = tl.load( descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx ) - descale_do = tl.load( - descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx - ) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 # if unaligned start_m is negative, the current N-tile has no block on the # diagonal of causal mask, so everything have no causal mask - dk, dv = _bwd_fused_atomics_dkdvdq_inner( + dk, dv = _bwd_dkdvdq_inner_atomic( dk, dv, # output tensors q_ptr_adj, @@ -1029,7 +1232,6 @@ def _bwd_kernel_fused_atomics_dkdvdq_causal( descale_q, descale_k, descale_v, - descale_do, # fp8 descale factors from user MASK_BLOCK_M, BLOCK_N, # block dim BLOCK_D_MODEL, @@ -1045,7 +1247,7 @@ def _bwd_kernel_fused_atomics_dkdvdq_causal( num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) end_m = start_m + num_steps * BLOCK_M - dk, dv = _bwd_fused_atomics_dkdvdq_inner( + dk, dv = _bwd_dkdvdq_inner_atomic( dk, dv, # output tensors q_ptr_adj, @@ -1077,7 +1279,6 @@ def _bwd_kernel_fused_atomics_dkdvdq_causal( descale_q, descale_k, descale_v, - descale_do, # fp8 descale factors from user BLOCK_M, BLOCK_N, # block dim BLOCK_D_MODEL, @@ -1103,7 +1304,7 @@ def _bwd_kernel_fused_atomics_dkdvdq_causal( @triton.jit -def _bwd_kernel_fused_atomics_dkdv_causal( +def _bwd_kernel_split_dkdv_causal( q_ptr, k_ptr, v_ptr, @@ -1143,7 +1344,6 @@ def _bwd_kernel_fused_atomics_dkdv_causal( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1155,7 +1355,6 @@ def _bwd_kernel_fused_atomics_dkdv_causal( descale_q_ptr, descale_k_ptr, descale_v_ptr, - descale_do_ptr, NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr, BLOCK_M: tl.constexpr, @@ -1307,15 +1506,12 @@ def _bwd_kernel_fused_atomics_dkdv_causal( descale_v = tl.load( descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx ) - descale_do = tl.load( - descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx - ) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 # if start_m is negative, the current N-tile has no block on the # diagonal of causal mask, so everything have no causal mask - dk, dv = _bwd_fused_atomics_dkdv_inner( + dk, dv = _bwd_dkdv_inner_split( dk, dv, # output tensors q_ptr_adj, @@ -1344,7 +1540,6 @@ def _bwd_kernel_fused_atomics_dkdv_causal( descale_q, descale_k, descale_v, - descale_do, # fp8 descale factors from user MASK_BLOCK_M, BLOCK_N, # block dim BLOCK_D_MODEL, @@ -1358,7 +1553,7 @@ def _bwd_kernel_fused_atomics_dkdv_causal( num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) end_m = start_m + num_steps * BLOCK_M - dk, dv = _bwd_fused_atomics_dkdv_inner( + dk, dv = _bwd_dkdv_inner_split( dk, dv, # output tensors q_ptr_adj, @@ -1387,7 +1582,6 @@ def _bwd_kernel_fused_atomics_dkdv_causal( descale_q, descale_k, descale_v, - descale_do, # fp8 descale factors from user BLOCK_M, BLOCK_N, # block dim BLOCK_D_MODEL, @@ -1412,7 +1606,7 @@ def _bwd_kernel_fused_atomics_dkdv_causal( @triton.jit -def _bwd_kernel_fused_atomics_dq_causal( +def _bwd_kernel_split_dq_causal( q_ptr, k_ptr, v_ptr, @@ -1451,7 +1645,6 @@ def _bwd_kernel_fused_atomics_dq_causal( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1463,7 +1656,6 @@ def _bwd_kernel_fused_atomics_dq_causal( descale_q_ptr, descale_k_ptr, descale_v_ptr, - descale_do_ptr, NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr, BLOCK_M: tl.constexpr, @@ -1582,11 +1774,8 @@ def _bwd_kernel_fused_atomics_dq_causal( descale_v = tl.load( descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx ) - descale_do = tl.load( - descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx - ) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32) # Compute dQ for masked (diagonal) blocks. @@ -1594,7 +1783,7 @@ def _bwd_kernel_fused_atomics_dq_causal( # but inside each call to _bwd_dq_inner, from left to right), but that's # not due to anything important. I just wanted to reuse the loop # structure for dK & dV above as much as possible. - dq = _bwd_fused_atomics_dq_inner( + dq = _bwd_dq_inner_split( dq, q, k_ptr_adj, @@ -1625,7 +1814,6 @@ def _bwd_kernel_fused_atomics_dq_causal( descale_q, descale_k, descale_v, - descale_do, BLOCK_M, MASK_BLOCK_N, BLOCK_D_MODEL, @@ -1638,7 +1826,7 @@ def _bwd_kernel_fused_atomics_dq_causal( end_n -= num_steps * MASK_BLOCK_N num_steps = tl.cdiv(end_n, BLOCK_N) start_n = max(end_n - num_steps * BLOCK_N, 0) - dq = _bwd_fused_atomics_dq_inner( + dq = _bwd_dq_inner_split( dq, q, k_ptr_adj, @@ -1669,7 +1857,6 @@ def _bwd_kernel_fused_atomics_dq_causal( descale_q, descale_k, descale_v, - descale_do, BLOCK_M, BLOCK_N, BLOCK_D_MODEL, @@ -1692,7 +1879,7 @@ def _bwd_kernel_fused_atomics_dq_causal( @triton.jit -def _bwd_kernel_fused_atomics_dkdvdq_noncausal( +def _bwd_kernel_fused_atomic_noncausal( Q, K, V, @@ -1737,7 +1924,6 @@ def _bwd_kernel_fused_atomics_dkdvdq_noncausal( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1749,7 +1935,6 @@ def _bwd_kernel_fused_atomics_dkdvdq_noncausal( descale_q_ptr, descale_k_ptr, descale_v_ptr, - descale_do_ptr, NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr, BATCH, @@ -1842,17 +2027,18 @@ def _bwd_kernel_fused_atomics_dkdvdq_noncausal( ) if IS_FP8: - descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hkid) descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 start_m = 0 num_steps = tl.cdiv(seqlen_q, BLOCK_M) - dk, dv = _bwd_fused_atomics_dkdvdq_inner( + dk, dv = _bwd_dkdvdq_inner_atomic( dk, dv, Q_ptr, @@ -1884,7 +2070,6 @@ def _bwd_kernel_fused_atomics_dkdvdq_noncausal( descale_q, descale_k, descale_v, - descale_do, BLOCK_M, BLOCK_N, BLOCK_D_MODEL, @@ -1909,7 +2094,7 @@ def _bwd_kernel_fused_atomics_dkdvdq_noncausal( @triton.jit -def _bwd_kernel_fused_atomics_dkdv_noncausal( +def _bwd_kernel_split_dkdv_noncausal( Q, K, V, @@ -1949,7 +2134,6 @@ def _bwd_kernel_fused_atomics_dkdv_noncausal( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1961,7 +2145,6 @@ def _bwd_kernel_fused_atomics_dkdv_noncausal( descale_q_ptr, descale_k_ptr, descale_v_ptr, - descale_do_ptr, NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr, BLOCK_M: tl.constexpr, @@ -2043,16 +2226,17 @@ def _bwd_kernel_fused_atomics_dkdv_noncausal( ) if IS_FP8: - descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hkid) descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 start_m = 0 num_steps = tl.cdiv(seqlen_q, BLOCK_M) - dk, dv = _bwd_fused_atomics_dkdv_inner( + dk, dv = _bwd_dkdv_inner_split( dk, dv, Q_ptr, @@ -2081,7 +2265,6 @@ def _bwd_kernel_fused_atomics_dkdv_noncausal( descale_q, descale_k, descale_v, - descale_do, BLOCK_M, BLOCK_N, BLOCK_D_MODEL, @@ -2105,7 +2288,7 @@ def _bwd_kernel_fused_atomics_dkdv_noncausal( @triton.jit -def _bwd_kernel_fused_atomics_dq_noncausal( +def _bwd_kernel_split_dq_noncausal( Q, K, V, @@ -2144,7 +2327,6 @@ def _bwd_kernel_fused_atomics_dq_noncausal( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -2156,7 +2338,6 @@ def _bwd_kernel_fused_atomics_dq_noncausal( descale_q_ptr, descale_k_ptr, descale_v_ptr, - descale_do_ptr, NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr, BLOCK_M: tl.constexpr, @@ -2229,18 +2410,19 @@ def _bwd_kernel_fused_atomics_dq_noncausal( # FP8 if IS_FP8: - descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hkid) descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 start_n = 0 end_n = seqlen_k num_steps = tl.cdiv(seqlen_k, BLOCK_N) dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32) - dq = _bwd_fused_atomics_dq_inner( + dq = _bwd_dq_inner_split( dq, q, K, @@ -2271,7 +2453,6 @@ def _bwd_kernel_fused_atomics_dq_noncausal( descale_q, descale_k, descale_v, - descale_do, BLOCK_M, BLOCK_N, BLOCK_D_MODEL, @@ -2314,10 +2495,8 @@ def _bwd_preprocess( stride_delta_b, stride_delta_h, stride_delta_m, - stride_descale_do_z, cu_seqlens_q, max_seqlen_q, - Descale_do, PRE_BLOCK: tl.constexpr, HEAD_DIM_V: tl.constexpr, ACTUAL_HEAD_DIM_V: tl.constexpr, @@ -2365,14 +2544,8 @@ def _bwd_preprocess( o = tl.load(O + off_o, mask=mask_md, other=0.0) do = tl.load(DO + off_do, mask=mask_md, other=0.0) # compute and write-back to delta - if IS_FP8: - off_descale_do = bid * stride_descale_do_z + hid - descale_do = tl.load(Descale_do + off_descale_do) - - # NOTE: do is in the fp8 range and o is not in fp8 - delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) - else: - delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) + # NOTE: Both o and do are FP32 + delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) off_delta = ( bid * stride_delta_b + hid * stride_delta_h @@ -2422,7 +2595,6 @@ def _bwd_dkdv_inner( descale_q, descale_k, descale_v, - descale_do, # fp8 descale factors from user MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal ENABLE_DROPOUT: tl.constexpr, # activate dropout USE_ALIBI: tl.constexpr, @@ -2477,15 +2649,8 @@ def _bwd_dkdv_inner( + offs_m[None, :] * stride_dropoutm + offs_n[:, None] * stride_dropoutn ) - if tl_DROPOUT_USE_PYTORCH: - dropout_offs = ( - offs_m[None, :] * stride_dropoutm - + offs_n[:, None] * stride_dropoutn - ) - dropout_mask = tl.load(curr_dropout_offset + dropout_offs, mask=mask_nm) - else: - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p dropout_scale = 1.0 / (1 - dropout_p) # Load m before computing qk to reduce pipeline stall. m = tl.load(M + offs_m * stride_lse_m, mask=mask_m, other=0.0) @@ -2529,27 +2694,9 @@ def _bwd_dkdv_inner( # Compute dV. if ENABLE_DROPOUT: pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors( - pT_dropout, FP8_MAX - ) - dv += ( - tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) - * descale_p_dropout - * descale_do - ) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) else: - if IS_FP8: - scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) - dv += ( - tl.dot((pT * scale_pT).to(do.type.element_ty), do) - * descale_pT - * descale_do - ) - else: - dv += tl.dot(pT.to(do.type.element_ty), do) + dv += tl.dot(pT.to(do.type.element_ty), do) if DEBUG_TRITON_DETAIL: if start_n == 256: @@ -2558,7 +2705,7 @@ def _bwd_dkdv_inner( Di = tl.load(D + offs_m * stride_delta_m, mask=mask_m) # Compute dP and dS. if IS_FP8: - dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do + dpT = tl.dot(v, tl.trans(do.to(v.type.element_ty))) * descale_v else: dpT = tl.dot(v, tl.trans(do)) if ENABLE_DROPOUT: @@ -2566,12 +2713,11 @@ def _bwd_dkdv_inner( delta_i = Di[None, :] dsT = pT * (dpT - delta_i) if IS_FP8: - scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += ( - tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) - * descale_dsT - * descale_q - ) + # Rewrite dk += dsT @ qT.T as dk += (qT @ dsT.T).T + # This puts FP8 tensor (qT) on LHS of dot product + # Cast the transposed dsT to FP8 to match qT's dtype + dsT_transposed = tl.trans(dsT).to(qT.type.element_ty) + dk += tl.trans(tl.dot(qT, dsT_transposed)) * descale_q else: dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) # Increment pointers. @@ -2624,7 +2770,6 @@ def _bwd_dq_inner( descale_q, descale_k, descale_v, - descale_do, # fp8 descale factors from user MASK: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, USE_ALIBI: tl.constexpr, @@ -2688,15 +2833,8 @@ def _bwd_dq_inner( + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn ) - if tl_DROPOUT_USE_PYTORCH: - dropout_offs = ( - offs_m[:, None] * stride_dropoutm - + offs_n[None, :] * stride_dropoutn - ) - dropout_mask = tl.load(curr_dropout_offset + dropout_offs, mask=mask_mn) - else: - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p dropout_scale = 1 / (1 - dropout_p) if IS_FP8: @@ -2724,7 +2862,7 @@ def _bwd_dq_inner( p = tl.where(mask, p, 0.0) # Compute dP and dS. if IS_FP8: - dp = tl.dot(do, vT) * descale_do * descale_v + dp = tl.dot(do.to(vT.type.element_ty), vT) * descale_v else: dp = tl.dot(do, vT) if ENABLE_DROPOUT: @@ -2734,12 +2872,11 @@ def _bwd_dq_inner( # Compute dQ. # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. if IS_FP8: - scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) - dq += ( - tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) - * descale_ds - * descale_k - ) + # Rewrite dq += ds @ kT.T as dq += (kT @ ds.T).T + # This puts FP8 tensor (kT) on LHS of dot product + # Cast the transposed ds to FP8 to match kT's dtype + ds_transposed = tl.trans(ds).to(kT.type.element_ty) + dq += tl.trans(tl.dot(kT, ds_transposed)) * descale_k else: dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) # Increment pointers. @@ -2755,7 +2892,7 @@ def _bwd_dq_inner( use_cuda_graph=True, ) @triton.jit -def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), batch) +def bwd_kernel_fused_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), batch) Q, K, V, @@ -2807,7 +2944,6 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, stride_az, stride_ah, HQ, @@ -2826,7 +2962,6 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b Descale_q, Descale_k, Descale_v, - Descale_do, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, BLOCK_M2: tl.constexpr, @@ -2842,7 +2977,6 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, - FP8_OUTPUT: tl.constexpr, USE_SEQUSED: tl.constexpr, # Add flag for seqused DEBUG_TRITON: tl.constexpr, DEBUG_TRITON_DETAIL: tl.constexpr, @@ -2944,8 +3078,8 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b + offs_d_v[None, :] * stride_vd ) # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k, mask=mask_k, other=0.0) - v = tl.load(V + adj_v, mask=mask_v, other=0.0) + k = tl.load(K + adj_k, mask=mask_k) + v = tl.load(V + adj_v, mask=mask_v) # If MQA / GQA, set the K and V head offsets appropriately. # hqid = hkid for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): @@ -2994,12 +3128,13 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b ) if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR # bound the masked operation to q len so it does not have to wast cycles @@ -3052,7 +3187,6 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b descale_q, descale_k, descale_v, - descale_do, MASK=True, # causal masking ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout USE_ALIBI=USE_ALIBI, @@ -3113,7 +3247,6 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b descale_q, descale_k, descale_v, - descale_do, MASK=False, # causal masking ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout USE_ALIBI=USE_ALIBI, @@ -3212,12 +3345,13 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N2) if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) dq = _bwd_dq_inner( @@ -3259,7 +3393,6 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b descale_q, descale_k, descale_v, - descale_do, MASK=True, # ENABLE_DROPOUT=ENABLE_DROPOUT, USE_ALIBI=USE_ALIBI, @@ -3315,7 +3448,6 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b descale_q, descale_k, descale_v, - descale_do, MASK=False, ENABLE_DROPOUT=ENABLE_DROPOUT, USE_ALIBI=USE_ALIBI, @@ -3339,7 +3471,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b use_cuda_graph=True, ) @triton.jit -def bwd_kernel_noncausal( +def bwd_kernel_fused_noncausal( Q, K, V, @@ -3391,7 +3523,6 @@ def bwd_kernel_noncausal( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, stride_az, stride_ah, HQ, @@ -3410,7 +3541,6 @@ def bwd_kernel_noncausal( Descale_q, Descale_k, Descale_v, - Descale_do, BLOCK_M1: tl.constexpr, # 32 BLOCK_N1: tl.constexpr, # 128 BLOCK_M2: tl.constexpr, # 128 @@ -3426,7 +3556,6 @@ def bwd_kernel_noncausal( USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, - FP8_OUTPUT: tl.constexpr, USE_SEQUSED: tl.constexpr, # Add flag for seqused DEBUG_TRITON: tl.constexpr, DEBUG_TRITON_DETAIL: tl.constexpr, @@ -3501,8 +3630,8 @@ def bwd_kernel_noncausal( + offs_d_v[None, :] * stride_vd ) # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k, mask=mask_k, other=0.0) - v = tl.load(V + adj_v, mask=mask_v, other=0.0) + k = tl.load(K + adj_k, mask=mask_k) + v = tl.load(V + adj_v, mask=mask_v) # If MQA / GQA, set the K and V head offsets appropriately. for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): # offset input and output tensor by batch and Q/K heads @@ -3536,12 +3665,13 @@ def bwd_kernel_noncausal( ) if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 # because there is no causal, we always start from the beginning start_m = 0 @@ -3583,7 +3713,6 @@ def bwd_kernel_noncausal( descale_q, descale_k, descale_v, - descale_do, # fp8 descale factors from user MASK=False, # causal masking ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout USE_ALIBI=USE_ALIBI, @@ -3657,12 +3786,13 @@ def bwd_kernel_noncausal( m = m[:, None] if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 # start can only be 0 at minimum start_n = 0 @@ -3709,7 +3839,6 @@ def bwd_kernel_noncausal( descale_q, descale_k, descale_v, - descale_do, MASK=False, ENABLE_DROPOUT=ENABLE_DROPOUT, USE_ALIBI=USE_ALIBI, @@ -3738,7 +3867,8 @@ def is_contiguous(x, name): DEBUG_TRITON_DETAIL: bool = False -def attention_backward_triton_split_fused_no_atomics_impl( +def attention_backward_triton_impl( + *, do: torch.Tensor, q: torch.Tensor, k: torch.Tensor, @@ -3748,6 +3878,7 @@ def attention_backward_triton_split_fused_no_atomics_impl( dq: torch.Tensor, dk: torch.Tensor, dv: torch.Tensor, + delta: torch.Tensor, sm_scale: float, alibi_slopes: Optional[torch.Tensor], causal: bool, @@ -3756,22 +3887,13 @@ def attention_backward_triton_split_fused_no_atomics_impl( cu_seqlens_k: Optional[torch.Tensor], max_seqlen_q: Optional[int], max_seqlen_k: Optional[int], - dropout_p: float, - philox_seed: Optional[int], - philox_offset: Optional[int], - use_exp2: bool, - # fp8 - descale_q: Optional[torch.Tensor], - descale_k: Optional[torch.Tensor], - descale_v: Optional[torch.Tensor], - descale_o: Optional[torch.Tensor], - descale_do: Optional[torch.Tensor], - descale_dq: Optional[torch.Tensor], - descale_dk: Optional[torch.Tensor], - descale_dv: Optional[torch.Tensor], - # seqused for FA v3 seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + philox_seed: Optional[int] = None, + philox_offset: Optional[int] = None, + use_exp2: bool = True, + mode: Literal["fused", "fused_atomic", "split"] = "fused", ): # get params, strides and shape IS_VARLEN = layout == "thd" @@ -3784,9 +3906,7 @@ def attention_backward_triton_split_fused_no_atomics_impl( assert ( q.device == k.device == v.device == o.device == do.device == softmax_lse.device ), f"All tensors must be on the same device. Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}, do={do.device}, softmax_lse={softmax_lse.device}" - assert ( - q.dtype == k.dtype == v.dtype == do.dtype - ), "q, k, v, do must have the same dtype" + assert q.dtype == k.dtype == v.dtype, "q, k, v must have the same dtype" current_device = torch.cuda.current_device() assert ( q.is_cuda and q.device.index == current_device @@ -3980,42 +4100,34 @@ def attention_backward_triton_split_fused_no_atomics_impl( stride_dob, stride_dom, stride_doh, stride_dod = do.stride() stride_lse_b, stride_lse_h, stride_lse_m = softmax_lse.stride() - # fp8 setup - moved after all assertions - IS_FP8 = is_fp8(q) + # fp8 + IS_FP8 = is_fp8([q, k, v]) if IS_FP8: FP8_MAX = torch.finfo(q.dtype).max - # we already asserted that do, q, k, v all have the same dtype, so no need to check each one - if is_fp8(o): - FP8_OUTPUT = True - assert ( - descale_o is not None - ), f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o." - assert ( - descale_dq is not None - ), f"descale_dq is None. In fp8, you need to pass a tensor for descale_dq along with a tensor dq." - assert ( - descale_dk is not None - ), f"descale_dk is None. In fp8, you need to pass a tensor for descale_dk along with a tensor dk." - assert ( - descale_dv is not None - ), f"descale_dv is None. In fp8, you need to pass a tensor for descale_dv along with a tensor dv." - else: - FP8_OUTPUT = False + + warnings.warn( + "FP8 tensors detected in backward pass. Backward pass supports FP8 inputs but " + "descaling factors will default to 1.0.", + UserWarning, + ) + + # For GQA/MQA, q_descale should be shaped (batch, nheads_k) to match forward pass + descale_q = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) + + descale_k = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) + + descale_v = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None - stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None - stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None if DEBUG: - print(f"FP8 path triggered (FP8_OUTPUT={FP8_OUTPUT})") + print(f"FP8 path triggered in bwd.py") else: FP8_MAX = None - FP8_OUTPUT = False - stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = ( - stride_descale_o_z - ) = stride_descale_do_z = None + descale_q = descale_k = descale_v = None + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = None # alibi setup use_alibi, (stride_az, stride_ah) = ( @@ -4032,11 +4144,18 @@ def attention_backward_triton_split_fused_no_atomics_impl( ACTUAL_HEAD_DIM_QK = head_size_qk ACTUAL_HEAD_DIM_V = head_size_v - # init delta + # Validate pre-allocated delta tensor if IS_VARLEN: # Shape expected by interface varlen backward: (Hq, Total_Q) total_q, _, _ = q.shape - delta = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) + assert ( + delta.shape[0] == nheads_q + ), f"delta.shape[0] ({delta.shape[0]}) must equal nheads_q ({nheads_q})" + assert ( + delta.shape[1] >= total_q + ), f"delta.shape[1] ({delta.shape[1]}) must be >= total_q ({total_q})" + assert delta.dtype == torch.float32, f"delta must be float32, got {delta.dtype}" + assert delta.device == q.device, f"delta must be on same device as q" stride_delta_b, stride_delta_h, stride_delta_m = ( 0, delta.stride(0), @@ -4045,9 +4164,17 @@ def attention_backward_triton_split_fused_no_atomics_impl( else: # Shape expected by dense backward: (B, Hq, Sq) seqlen_q = q.shape[1] - delta = torch.zeros( - (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 - ) + assert ( + delta.shape[0] == batch + ), f"delta.shape[0] ({delta.shape[0]}) must equal batch ({batch})" + assert ( + delta.shape[1] == nheads_q + ), f"delta.shape[1] ({delta.shape[1]}) must equal nheads_q ({nheads_q})" + assert ( + delta.shape[2] >= seqlen_q + ), f"delta.shape[2] ({delta.shape[2]}) must be >= seqlen_q ({seqlen_q})" + assert delta.dtype == torch.float32, f"delta must be float32, got {delta.dtype}" + assert delta.device == q.device, f"delta must be on same device as q" stride_delta_b, stride_delta_h, stride_delta_m = delta.stride() pre_grid = lambda META: ( @@ -4070,10 +4197,8 @@ def attention_backward_triton_split_fused_no_atomics_impl( stride_delta_b, stride_delta_h, stride_delta_m, - stride_descale_do_z, cu_seqlens_q, max_seqlen_q, - descale_do, HEAD_DIM_V=HEAD_DIM_V, ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, IS_VARLEN=IS_VARLEN, @@ -4094,364 +4219,205 @@ def attention_backward_triton_split_fused_no_atomics_impl( dtype=torch.float32, ) - if DROPOUT_USE_PYTORCH: - if not IS_VARLEN: - dropout_mask = create_dropout_mask( - dropout_p, - (batch, nheads_q, max_seqlen_q, max_seqlen_k), - seed=philox_seed, - ) - else: - dropout_mask = create_dropout_mask_varlen( - dropout_p, batch, nheads_q, cu_seqlens_q, cu_seqlens_k, philox_seed - ) stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = ( dropout_mask.stride() ) - seqlen = max(max_seqlen_q, max_seqlen_k) - grid = lambda META: ( - nheads_k, - (seqlen + META["BLOCK_N1"] - 1) // META["BLOCK_N1"], - batch, - ) - if causal: - if DEBUG_TRITON: - print(f"bwd_kernel: grid = {grid}") # noqa: E701 - bwd_kernel_causal[grid]( - q, - k, - v, - sm_scale, - do, - dq, - dk, - dv, - softmax_lse, - delta, - stride_qb, - stride_qh, - stride_qm, - stride_qd, - stride_kb, - stride_kh, - stride_kn, - stride_kd, - stride_vb, - stride_vh, - stride_vn, - stride_vd, - stride_dqb, - stride_dqh, - stride_dqm, - stride_dqd, - stride_dkb, - stride_dkh, - stride_dkn, - stride_dkd, - stride_dvb, - stride_dvh, - stride_dvn, - stride_dvd, - stride_lse_b, - stride_lse_h, - stride_lse_m, - stride_delta_b, - stride_delta_h, - stride_delta_m, - stride_dob, - stride_doh, - stride_dom, - stride_dod, - stride_dropoutb, - stride_dropouth, - stride_dropoutm, - stride_dropoutn, - stride_descale_q_z, - stride_descale_k_z, - stride_descale_v_z, - stride_descale_do_z, - stride_az, - stride_ah, - nheads_q, - nheads_k, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, # Pass seqused tensors - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - descale_q, - descale_k, - descale_v, - descale_do, - HEAD_DIM_QK=HEAD_DIM_QK, - HEAD_DIM_V=HEAD_DIM_V, - ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, - ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_ALIBI=use_alibi, - USE_EXP2=use_exp2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - FP8_OUTPUT=FP8_OUTPUT, - USE_SEQUSED=( - seqused_q is not None or seqused_k is not None - ), # Add flag for seqused - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - else: - bwd_kernel_noncausal[grid]( - q, - k, - v, - sm_scale, - do, - dq, - dk, - dv, - softmax_lse, - delta, - stride_qb, - stride_qh, - stride_qm, - stride_qd, - stride_kb, - stride_kh, - stride_kn, - stride_kd, - stride_vb, - stride_vh, - stride_vn, - stride_vd, - stride_dqb, - stride_dqh, - stride_dqm, - stride_dqd, - stride_dkb, - stride_dkh, - stride_dkn, - stride_dkd, - stride_dvb, - stride_dvh, - stride_dvn, - stride_dvd, - stride_lse_b, - stride_lse_h, - stride_lse_m, - stride_delta_b, - stride_delta_h, - stride_delta_m, - stride_dob, - stride_doh, - stride_dom, - stride_dod, - stride_dropoutb, - stride_dropouth, - stride_dropoutm, - stride_dropoutn, - stride_descale_q_z, - stride_descale_k_z, - stride_descale_v_z, - stride_descale_do_z, - stride_az, - stride_ah, - nheads_q, + # Choose which kernels to call based on mode + if mode == "fused": + seqlen = max(max_seqlen_q, max_seqlen_k) + grid = lambda META: ( nheads_k, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, # Pass seqused tensors - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - descale_q, - descale_k, - descale_v, - descale_do, - HEAD_DIM_QK=HEAD_DIM_QK, - HEAD_DIM_V=HEAD_DIM_V, - ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, - ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_ALIBI=use_alibi, - USE_EXP2=use_exp2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - FP8_OUTPUT=FP8_OUTPUT, - USE_SEQUSED=( - seqused_q is not None or seqused_k is not None - ), # Add flag for seqused - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - return delta - - -def attention_backward_triton_fused_atomics_impl( - do: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - softmax_lse: torch.Tensor, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - philox_seed: Optional[int] = 0, - philox_offset: Optional[int] = 0, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None, - fused: bool = False, - # seqused for FA v3 (currently ignored in this implementation) - seqused_q: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, -): - IS_FP8 = is_fp8(q) - if IS_FP8: - FP8_MAX = torch.finfo(q.dtype).max - descale_strides = ( - descale_q.stride(0), - descale_k.stride(0), - descale_v.stride(0), - descale_do.stride(0), - ) - - if DEBUG: - print(f"FP8 path triggered") - else: - FP8_MAX = None - stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = ( - stride_descale_do_z - ) = None - descale_strides = ( - stride_descale_q_z, - stride_descale_k_z, - stride_descale_v_z, - stride_descale_do_z, - ) - - IS_VARLEN = True if cu_seqlens_q is not None else False - - # get strides and shape - if IS_VARLEN: - # Layout for q,k,v is thd ie [total tokens, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = ( - len(cu_seqlens_q) - 1, - max_seqlen_q, - q.shape[1], - q.shape[2], + (seqlen + META["BLOCK_N1"] - 1) // META["BLOCK_N1"], + batch, ) - seqlen_k, num_k_heads = max_seqlen_k, k.shape[1] - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - dq_strides = (0, dq.stride(1), dq.stride(0), dq.stride(2)) - dk_strides = (0, dk.stride(1), dk.stride(0), dk.stride(2)) - dv_strides = (0, dv.stride(1), dv.stride(0), dv.stride(2)) - do_strides = (0, do.stride(1), do.stride(0), do.stride(2)) - else: - # Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = q.shape - seqlen_k, num_k_heads = k.shape[1], k.shape[2] - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) - dq_strides = (dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3)) - dk_strides = (dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3)) - dv_strides = (dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3)) - do_strides = (do.stride(0), do.stride(2), do.stride(1), do.stride(3)) - - # BLOCK_D_MODEL, BLOCK_D_MODEL_POW2 - # padding for head_dim. Power of 2 or 16 - BLOCK_D_MODEL_POW2 = triton.next_power_of_2(head_sz) - BLOCK_D_MODEL_POW2 = max(BLOCK_D_MODEL_POW2, 16) - - # Configs - # PRE_BLOCK, BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 - # BLK_SLICE_FACTOR - NUM_WARPS, NUM_STAGES = 4, 1 - WAVES_PER_EU = 1 - PRE_BLOCK = 128 - # BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 64, 64, 16 - BLK_SLICE_FACTOR = 2 - - # init delta - delta = torch.zeros_like(softmax_lse) - if IS_VARLEN: - # [total_tokens, num_q_heads, seqlen_q] - delta_strides = (0, delta.stride(1), delta.stride(0)) - else: - # [batch, num_q_heads, seqlen_q] - delta_strides = delta.stride() - - # preprocess - # compute D(delta) = rowsum(dO*O). Note, multiplication is element-wise. - pre_grid = (triton.cdiv(max_seqlen_q, PRE_BLOCK), batch, num_q_heads) - _bwd_fused_atomics_preprocess[pre_grid]( - o, - do, - delta, - *o_strides, - *delta_strides, - descale_strides[3], - cu_seqlens_q, - max_seqlen_q, - descale_do, - BLOCK_M=PRE_BLOCK, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - ) - - # dropout_mask - use_dropout = dropout_p > 0.0 - if use_dropout: - dropout_mask = torch.zeros( - (batch, num_q_heads, max_seqlen_q, max_seqlen_k), - device=q.device, - dtype=torch.float32, - ) - dropout_strides = dropout_mask.stride() - else: - dropout_mask = None - dropout_strides = (0, 0, 0, 0) - - grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, num_k_heads) - grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, num_k_heads) + if causal: + if DEBUG_TRITON: + print(f"bwd_kernel: grid = {grid}") # noqa: E701 + bwd_kernel_fused_causal[grid]( + q, + k, + v, + sm_scale, + do, + dq, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_az, + stride_ah, + nheads_q, + nheads_k, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Pass seqused tensors + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, + descale_q, + descale_k, + descale_v, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + USE_SEQUSED=( + seqused_q is not None or seqused_k is not None + ), # Add flag for seqused + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + else: + bwd_kernel_fused_noncausal[grid]( + q, + k, + v, + sm_scale, + do, + dq, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_az, + stride_ah, + nheads_q, + nheads_k, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Pass seqused tensors + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, + descale_q, + descale_k, + descale_v, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + USE_SEQUSED=( + seqused_q is not None or seqused_k is not None + ), # Add flag for seqused + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + elif mode == "fused_atomic": + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 64, 64, 16 + BLK_SLICE_FACTOR = 2 + BLOCK_D_MODEL_POW2 = max(triton.next_power_of_2(HEAD_DIM_QK), 16) - if ( - fused - ): # fuses dk, dv, dq computations into one kernel by computing the dq using atomic adds between workgroups + grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) + grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) + # fuses dk, dv, dq computations into one kernel by computing the dq using atomic adds between workgroups BLOCK_N = ( 128 if BLOCK_D_MODEL_POW2 < 160 else 64 ) # larger head sizes lead to oom @@ -4465,10 +4431,10 @@ def attention_backward_triton_fused_atomics_impl( } num_k_pids = (max_seqlen_k + BLOCK_N - 1) // BLOCK_N - grid_dkdvdq = (batch * num_k_heads * num_k_pids,) + grid_dkdvdq = (batch * nheads_k * num_k_pids,) if causal: - _bwd_kernel_fused_atomics_dkdvdq_causal[grid_dkdvdq]( + _bwd_kernel_fused_atomic_causal[grid_dkdvdq]( q, k, v, @@ -4479,15 +4445,36 @@ def attention_backward_triton_fused_atomics_impl( dq, softmax_lse, delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -4499,12 +4486,11 @@ def attention_backward_triton_fused_atomics_impl( descale_q, descale_k, descale_v, - descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, BATCH=batch, NUM_K_PIDS=num_k_pids, - BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL=HEAD_DIM_QK, BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, @@ -4513,7 +4499,7 @@ def attention_backward_triton_fused_atomics_impl( **config, ) else: - _bwd_kernel_fused_atomics_dkdvdq_noncausal[grid_dkdvdq]( + _bwd_kernel_fused_atomic_noncausal[grid_dkdvdq]( q, k, v, @@ -4524,15 +4510,36 @@ def attention_backward_triton_fused_atomics_impl( dq, softmax_lse, delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -4544,12 +4551,11 @@ def attention_backward_triton_fused_atomics_impl( descale_q, descale_k, descale_v, - descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, BATCH=batch, NUM_K_PIDS=num_k_pids, - BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL=HEAD_DIM_QK, BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, @@ -4557,302 +4563,282 @@ def attention_backward_triton_fused_atomics_impl( FP8_MAX=FP8_MAX, **config, ) + elif mode == "split": + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 64, 64, 16 + BLK_SLICE_FACTOR = 2 + BLOCK_D_MODEL_POW2 = max(triton.next_power_of_2(HEAD_DIM_QK), 16) - return delta - - # split kernels solution: one kernel computes dk, dv and the other computes dq - - if causal: - _bwd_kernel_fused_atomics_dkdv_causal[grid_dkdv]( - q, - k, - v, - sm_scale, - do, - dk, - dv, - softmax_lse, - delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset, - descale_q, - descale_k, - descale_v, - descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M1, - BLOCK_N=BLOCK_N1, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - _bwd_kernel_fused_atomics_dq_causal[grid_dq]( - q, - k, - v, - sm_scale, - do, - dq, - softmax_lse, - delta, - *q_strides, - *k_strides, - *v_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset, - descale_q, - descale_k, - descale_v, - descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M2, - BLOCK_N=BLOCK_N2, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - else: - _bwd_kernel_fused_atomics_dkdv_noncausal[grid_dkdv]( - q, - k, - v, - sm_scale, - do, - dk, - dv, - softmax_lse, - delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset, - descale_q, - descale_k, - descale_v, - descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M1, - BLOCK_N=BLOCK_N1, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - - _bwd_kernel_fused_atomics_dq_noncausal[grid_dq]( - q, - k, - v, - sm_scale, - do, - dq, - softmax_lse, - delta, - *q_strides, - *k_strides, - *v_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset, - descale_q, - descale_k, - descale_v, - descale_do, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M2, - BLOCK_N=BLOCK_N2, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - - return delta - + grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) + grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) -def attention_backward_triton_impl( - *, - do: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - softmax_lse: torch.Tensor, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - layout: str, - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], - seqused_q: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - philox_seed: Optional[int] = None, - philox_offset: Optional[int] = None, - use_exp2: bool = True, - mode: str = "fused_no_atomics", -) -> torch.Tensor: - """Unified backward interface dispatching to atomics or no-atomics implementation. - - Parameters mirror the superset of the two legacy interfaces. The public API should - call ONLY this function going forward. - mode: 'fused_atomics' or 'fused_no_atomics'; layout: 'bshd' or 'thd'; use_exp2 retained for parity. - """ - # Enforce supported dtypes (mirror Hopper behavior: FP8 forward-only) - supported_dtypes = {torch.float16, torch.bfloat16, torch.float32} - for name, t in {"q": q, "k": k, "v": v, "o": o, "do": do}.items(): - if t.dtype not in supported_dtypes: - raise TypeError( - f"Backward only supports fp16/bf16/fp32; tensor '{name}' has dtype {t.dtype}" + if causal: + _bwd_kernel_split_dkdv_causal[grid_dkdv]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M1, + BLOCK_N=BLOCK_N1, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + _bwd_kernel_split_dq_causal[grid_dq]( + q, + k, + v, + sm_scale, + do, + dq, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M2, + BLOCK_N=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + else: + _bwd_kernel_split_dkdv_noncausal[grid_dkdv]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M1, + BLOCK_N=BLOCK_N1, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, ) - if mode == "fused_atomics": - # Atomics path ignores layout & use_exp2; pass varlen metadata directly. - return attention_backward_triton_fused_atomics_impl( - do, - q, - k, - v, - o, - softmax_lse, - dq, - dk, - dv, - sm_scale, - alibi_slopes, - causal, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q if max_seqlen_q is not None else q.shape[1], - max_seqlen_k if max_seqlen_k is not None else k.shape[1], - dropout_p, - philox_seed or 0, - philox_offset or 0, - None, - None, - None, - None, - True, # fused flag - None, - None, - ) - elif mode == "fused_no_atomics": - return attention_backward_triton_split_fused_no_atomics_impl( - do, - q, - k, - v, - o, - softmax_lse, - dq, - dk, - dv, - sm_scale, - alibi_slopes, - causal, - layout, # layout required here - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - use_exp2, - None, - None, - None, - None, - None, - None, - None, - None, - seqused_q, - seqused_k, - ) + _bwd_kernel_split_dq_noncausal[grid_dq]( + q, + k, + v, + sm_scale, + do, + dq, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M2, + BLOCK_N=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) else: raise ValueError( - f"Unknown backward mode '{mode}'. Expected 'fused_atomics' or 'fused_no_atomics'." + f"Unknown backward mode '{mode}'. Expected 'split', 'fused_atomic' or 'fused'." ) diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index bb7edad3494..4645dcc97fe 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -1,3 +1,5 @@ +import os +import warnings import torch import triton import triton.language as tl @@ -5,11 +7,13 @@ from .utils import ( DEBUG, AUTOTUNE, + get_arch, get_padded_headsize, get_shape_and_strides_from_layout, apply_rotary, is_cdna, is_fp8, + get_recommended_fp8_dtype, ) @@ -842,6 +846,7 @@ def attention_forward_decode_triton_impl( k_new: Optional[torch.Tensor], v_new: Optional[torch.Tensor], out: torch.Tensor, + softmax_lse: torch.Tensor, sm_scale: float, causal: bool, window_size_left: int, @@ -1025,13 +1030,13 @@ def attention_forward_decode_triton_impl( stride_kn_h, stride_kn_n, stride_kn_d, - ) = (None, None, None, None), (None, None, None, None) + ) = (None, None, None, None,), (None, None, None, None) (_, seqlen_vn, nheads_vn, dim_vn), ( stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d, - ) = (None, None, None, None), (None, None, None, None) + ) = (None, None, None, None,), (None, None, None, None) (_, seqlen_o, nheads_o, dim_o), (stride_oz, stride_oh, stride_om, stride_od) = ( get_shape_and_strides_from_layout(out, layout) ) @@ -1100,11 +1105,27 @@ def attention_forward_decode_triton_impl( dtype=torch.float32, device=q.device, ) - lse = torch.empty( - (batch_size * n_group_q * heads_per_group_q, seqlen_q), - dtype=torch.float32, - device=q.device, - ) + + # Validate pre-allocated softmax_lse tensor + # Expected shape after view: (batch_size, n_group_q * heads_per_group_q, seqlen_q) + # Internal shape: (batch_size * n_group_q * heads_per_group_q, seqlen_q) + expected_h_total = batch_size * n_group_q * heads_per_group_q + assert ( + softmax_lse.shape[0] == batch_size + ), f"softmax_lse.shape[0] ({softmax_lse.shape[0]}) must equal batch_size ({batch_size})" + assert ( + softmax_lse.shape[1] == n_group_q * heads_per_group_q + ), f"softmax_lse.shape[1] ({softmax_lse.shape[1]}) must equal n_group_q * heads_per_group_q ({n_group_q * heads_per_group_q})" + assert ( + softmax_lse.shape[2] >= seqlen_q + ), f"softmax_lse.shape[2] ({softmax_lse.shape[2]}) must be >= seqlen_q ({seqlen_q})" + assert ( + softmax_lse.dtype == torch.float32 + ), f"softmax_lse must be float32, got {softmax_lse.dtype}" + assert softmax_lse.device == q.device, f"softmax_lse must be on same device as q" + + # Create internal lse view for kernel use + lse = softmax_lse.view(expected_h_total, -1)[:, :seqlen_q].contiguous() # get intermediate tensor strides stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d = out_splitk.stride() @@ -1118,11 +1139,20 @@ def attention_forward_decode_triton_impl( stride_bt_b, stride_bt_s = 0, 0 # FP8 support - IS_FP8 = is_fp8(q) + IS_FP8 = is_fp8([q, k_cache, v_cache]) if IS_FP8: + rec_dtype = get_recommended_fp8_dtype(q) + if ( + q.dtype != rec_dtype + or k_cache.dtype != rec_dtype + or v_cache.dtype != rec_dtype + ): + arch = get_arch() + warnings.warn( + f"Use {rec_dtype} data type on {arch}. Got q: {q.dtype}, k: {k_cache.dtype}, v: {v_cache.dtype}", + UserWarning, + ) if (q_descale is None) or (k_descale is None) or (v_descale is None): - import warnings - warnings.warn( "FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning, @@ -1140,6 +1170,23 @@ def attention_forward_decode_triton_impl( v_descale = torch.ones( batch_size, nheads_vc, dtype=torch.float32, device=q.device ) + else: + # Enforce exact expected shapes; no reshaping or normalization. + assert ( + q_descale.dim() == 2 + and q_descale.shape[0] == batch_size + and q_descale.shape[1] == nheads_kc + ), f"q_descale expected shape ({batch_size}, {nheads_kc}) got {tuple(q_descale.shape)}" + assert ( + k_descale.dim() == 2 + and k_descale.shape[0] == batch_size + and k_descale.shape[1] == nheads_kc + ), f"k_descale expected shape ({batch_size}, {nheads_kc}) got {tuple(k_descale.shape)}" + assert ( + v_descale.dim() == 2 + and v_descale.shape[0] == batch_size + and v_descale.shape[1] == nheads_kc + ), f"v_descale expected shape ({batch_size}, {nheads_kc}) got {tuple(v_descale.shape)}" stride_q_descale_z, stride_q_descale_h = q_descale.stride() stride_k_descale_z, stride_k_descale_h = k_descale.stride() stride_v_descale_z, stride_v_descale_h = v_descale.stride() @@ -1355,5 +1402,3 @@ def attention_forward_decode_triton_impl( PADDED_HEAD=is_padded_head, num_warps=num_warps_reduce, ) - - return lse.view(batch_size, n_group_q * heads_per_group_q, seqlen_q) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index d1036f98c3f..81a1de19f20 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -1,4 +1,5 @@ import os +import warnings import torch import triton import triton.language as tl @@ -6,116 +7,22 @@ from .utils import ( DEBUG, AUTOTUNE, - DROPOUT_USE_PYTORCH, - DROPOUT_DUMP, compute_alibi_block, compute_fp8_scaling_factors, get_arch, + get_cu_count, is_cdna, is_fp8, is_rdna, - create_dropout_mask, apply_rotary, + get_recommended_fp8_dtype, ) -# NOTE: triton fails to import tl.constexprs so create them here for the file -tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) -tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) - - -# ------------------------------- -# Autotune -# ------------------------------- -def get_fwd_prefill_cdna_autotune_configs(): - return [ - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - # Fall-back config. - triton.Config( - {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - ], [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_BLOCK_DMODEL_QK", - "ACTUAL_BLOCK_DMODEL_V", - "IS_VARLEN", - "HQ", - "HK", - ] -def get_fwd_prefill_rdna_autotune_configs(): - return [ - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 4, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 16, "waves_per_eu": 4, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 16, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - triton.Config( - {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 4, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - triton.Config( - {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - # Fall-back config. - triton.Config( - {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - ], [ +def get_fwd_configs(autotune: bool): + configs = [] + keys = [ "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", @@ -127,59 +34,129 @@ def get_fwd_prefill_rdna_autotune_configs(): "HK", ] - -def get_fwd_prefill_autotune_configs(): - if AUTOTUNE: - if is_rdna(): - return get_fwd_prefill_rdna_autotune_configs() - elif is_cdna(): - return get_fwd_prefill_cdna_autotune_configs() - else: - raise ValueError("Unknown Device Type") - else: + # get best config for the architecture + if not autotune: arch = get_arch() if arch == "gfx950": - default_config = triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, + configs.append( + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ) ) - elif ( - arch == "gfx942" and False - ): # Disabled due shared mem oom in CI when using triton==3.3.0 when using top of tree everything seems fine. - default_config = triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, + elif arch == "gfx942": + if get_cu_count() < 304: + configs.extend( + [ + # best fp8 config + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + # best f16 config + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 32, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=2, + num_warps=4, + ), + ] + ) + else: + configs.append( + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ) + ) + elif arch in ( + "gfx1030", + "gfx1100", + "gfx1101", + "gfx1102", + "gfx1200", + "gfx1201", + ): # RDNA architectures + configs.append( + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 32, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=2, + ) ) else: - default_config = triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, + configs.append( + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ) ) - return [default_config], [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_BLOCK_DMODEL_QK", - "ACTUAL_BLOCK_DMODEL_V", - "IS_VARLEN", - "HQ", - "HK", - ] - - -fwd_prefill_autotune_configs, fwd_prefill_autotune_keys = ( - get_fwd_prefill_autotune_configs() -) + return configs, keys + + # ===================== Autotune Sweep ===================== + BLOCK_M_OPTIONS = [128, 64, 32] + BLOCK_N_OPTIONS = [128, 64, 32] + NUM_WARPS_OPTIONS = [2, 4, 8] + NUM_STAGES_OPTIONS = [1, 2] + WAVES_PER_EU_OPTIONS = [4, 2, 1] + PRE_LOAD_V_OPTIONS = [False] + for bm in BLOCK_M_OPTIONS: + for bn in BLOCK_N_OPTIONS: + for waves in WAVES_PER_EU_OPTIONS: + for nw in NUM_WARPS_OPTIONS: + for ns in NUM_STAGES_OPTIONS: + for preload_v in PRE_LOAD_V_OPTIONS: + configs.append( + triton.Config( + { + "BLOCK_M": bm, + "BLOCK_N": bn, + "waves_per_eu": waves, + "PRE_LOAD_V": preload_v, + }, + num_stages=ns, + num_warps=nw, + ) + ) + + return configs, keys + + +fwd_prefill_autotune_configs, fwd_prefill_autotune_keys = get_fwd_configs(AUTOTUNE) @triton.jit @@ -195,14 +172,18 @@ def _attn_fwd_no_mask( stride_vk, stride_bn, stride_sn, + stride_sm, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, - philox_base_ptrs, - sd_mask_base_ptrs, - dropout_mask_base_ptrs, + philox_offset_base, + sd_mask, + stride_sz, + stride_sh, + off_z, + off_h_q, offs_m, offs_n, offs_d_qk, @@ -242,28 +223,20 @@ def _attn_fwd_no_mask( v_ptrs = v_base_ptrs + start_n * stride_vk kv_offs_n = start_n + tl.arange(0, BLOCK_N) + # Load K if PADDED_HEAD_QK: - k_mask, k_mask_other = (offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK), 0.0 - else: - k_mask, k_mask_other = None, None - - if PADDED_HEAD_V: - v_mask, v_mask_other = (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V), 0.0 + k_mask = offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK + k = tl.load(k_ptrs, mask=k_mask, other=0.0) else: - v_mask, v_mask_other = None, None + k = tl.load(k_ptrs) - # load k and if preload_v then v - k = ( - tl.load(k_ptrs, mask=k_mask, other=k_mask_other) - if PADDED_HEAD_QK - else tl.load(k_ptrs) - ) + # Optionally preload V if PRE_LOAD_V: - v = ( - tl.load(v_ptrs, mask=v_mask, other=v_mask_other) - if PADDED_HEAD_V - else tl.load(v_ptrs) - ) + if PADDED_HEAD_V: + v_mask = offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + else: + v = tl.load(v_ptrs) # setup qk accumlator qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) @@ -309,29 +282,34 @@ def _attn_fwd_no_mask( # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: - dropout_mask_ptrs = dropout_mask_base_ptrs + start_n * stride_sn - sd_mask_ptrs = sd_mask_base_ptrs + start_n * stride_sn - philox_ptrs = philox_base_ptrs + start_n * stride_sn - if tl_DROPOUT_USE_PYTORCH: - dropout_mask = tl.load(dropout_mask_ptrs, mask=qk_mask) - else: - rng_output = tl.rand( - philox_seed, philox_ptrs - ) # TODO: use tl.randint for better performance - dropout_mask = rng_output > dropout_p - if tl_DROPOUT_DUMP: - tl.store(dropout_mask_ptrs, dropout_mask, mask=qk_mask) - - # return scores with negative values for dropped vals - sd_mask = tl.where(dropout_mask, p, -p) - tl.store(sd_mask_ptrs, sd_mask, mask=qk_mask) + # Compute pointers for this block + philox_base = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh + philox_ptrs = philox_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn + + # compute dropout mask + rng_output = tl.rand(philox_seed, philox_ptrs) + dropout_mask = rng_output > dropout_p + + # return scores with negative values for dropped vals (only if RETURN_SCORES is True) + if RETURN_SCORES: + sd_mask_value = tl.where(dropout_mask, p, -p) + sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh + sd_mask_ptrs = sd_mask_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn + + # Compute mask for sd_mask storage + sd_store_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) + tl.store(sd_mask_ptrs, sd_mask_value, mask=sd_store_mask) # apply dropout mask in place p = tl.where(dropout_mask, p, 0.0) elif RETURN_SCORES: # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - sd_mask_ptrs = sd_mask_base_ptrs + start_n * stride_sn - tl.store(sd_mask_ptrs, p, mask=qk_mask) + sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh + sd_mask_ptrs = sd_mask_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn + + # Compute mask for sd_mask storage + sd_store_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) + tl.store(sd_mask_ptrs, p, mask=sd_store_mask) # -- update output accumulator -- # alpha is an adjustment factor for acc and li as we loop and find new maxes @@ -343,11 +321,11 @@ def _attn_fwd_no_mask( alpha = tl.math.exp(m_diff) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = ( - tl.load(v_ptrs, mask=v_mask, other=v_mask_other) - if PADDED_HEAD_V - else tl.load(v_ptrs) - ) + if PADDED_HEAD_V: + v_mask = offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + else: + v = tl.load(v_ptrs) # -- update m_i and l_i l_i = l_i * alpha + l_ij @@ -382,14 +360,18 @@ def _attn_fwd_mask( stride_vk, stride_bn, stride_sn, + stride_sm, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, - philox_base_ptrs, - sd_mask_base_ptrs, - dropout_mask_base_ptrs, + philox_offset_base, + sd_mask, + stride_sz, + stride_sh, + off_z, + off_h_q, offs_m, offs_n, offs_d_qk, @@ -611,29 +593,74 @@ def _attn_fwd_mask( # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: - dropout_mask_ptrs = dropout_mask_base_ptrs + start_n * stride_sn - sd_mask_ptrs = sd_mask_base_ptrs + start_n * stride_sn - philox_ptrs = philox_base_ptrs + start_n * stride_sn - if tl_DROPOUT_USE_PYTORCH: - dropout_mask = tl.load(dropout_mask_ptrs, mask=qk_mask) - else: - rng_output = tl.rand( - philox_seed, philox_ptrs - ) # TODO: use tl.randint for better performance - dropout_mask = rng_output > dropout_p - if tl_DROPOUT_DUMP: - tl.store(dropout_mask_ptrs, dropout_mask, mask=qk_mask) - - # return scores with negative values for dropped vals - sd_mask = tl.where(dropout_mask, p, -p) - tl.store(sd_mask_ptrs, sd_mask, mask=qk_mask) + # Compute pointers for this block + philox_base = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh + philox_ptrs = philox_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn + + # compute dropout mask + rng_output = tl.rand(philox_seed, philox_ptrs) + dropout_mask = rng_output > dropout_p + + # return scores with negative values for dropped vals (only if RETURN_SCORES is True) + if RETURN_SCORES: + sd_mask_value = tl.where(dropout_mask, p, -p) + sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh + sd_mask_ptrs = sd_mask_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn + + # Compute mask for sd_mask storage - include bounds check + sd_store_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) + + # Add causal mask if applicable to prevent writing to invalid positions + if IS_CAUSAL: + seqlen_delta_qk = seqlen_k - seqlen_q + causal_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk) + sd_store_mask = sd_store_mask & causal_constraint + + # Add sliding window mask if applicable + if USE_SLIDING_WINDOW: + seqlen_delta_qk = seqlen_k - seqlen_q + if WINDOW_SIZE_LEFT < 0: + # Only right window constraint + window_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT) + else: + # Both left and right window constraints + left_bound = offs_m[:, None] + seqlen_delta_qk - WINDOW_SIZE_LEFT + right_bound = offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT + window_constraint = (kv_offs_n[None, :] >= left_bound) & (kv_offs_n[None, :] <= right_bound) + sd_store_mask = sd_store_mask & window_constraint + + tl.store(sd_mask_ptrs, sd_mask_value, mask=sd_store_mask) # apply dropout mask in place p = tl.where(dropout_mask, p, 0.0) elif RETURN_SCORES: # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - sd_mask_ptrs = sd_mask_base_ptrs + start_n * stride_sn - tl.store(sd_mask_ptrs, p, mask=qk_mask) + sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh + sd_mask_ptrs = sd_mask_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn + + # Compute mask for sd_mask storage - include bounds check + sd_store_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) + + # Add causal mask if applicable + if IS_CAUSAL: + seqlen_delta_qk = seqlen_k - seqlen_q + causal_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk) + sd_store_mask = sd_store_mask & causal_constraint + + # Add sliding window mask if applicable + if USE_SLIDING_WINDOW: + seqlen_delta_qk = seqlen_k - seqlen_q + if WINDOW_SIZE_LEFT < 0: + # Only right window constraint + window_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT) + else: + # Both left and right window constraints + left_bound = offs_m[:, None] + seqlen_delta_qk - WINDOW_SIZE_LEFT + right_bound = offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT + window_constraint = (kv_offs_n[None, :] >= left_bound) & (kv_offs_n[None, :] <= right_bound) + sd_store_mask = sd_store_mask & window_constraint + + tl.store(sd_mask_ptrs, p, mask=sd_store_mask) # -- update output accumulator -- # alpha is an adjustment factor for acc and li as we loop and find new maxes @@ -974,9 +1001,10 @@ def attn_fwd( stride_q_descale_z, stride_k_descale_z, stride_v_descale_z, - SM_SCALE: tl.constexpr, LSE, Out, + SD_MASK, + ALIBI_SLOPES, stride_qz, stride_qh, stride_qm, @@ -1013,9 +1041,6 @@ def attn_fwd( dropout_p, philox_seed, philox_offset_base, - sd_mask, - dropout_mask, - alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, @@ -1023,6 +1048,7 @@ def attn_fwd( MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, + SM_SCALE: tl.constexpr, IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, @@ -1035,7 +1061,6 @@ def attn_fwd( USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, - NEEDS_SDMASK: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr, @@ -1215,41 +1240,10 @@ def attn_fwd( if USE_ALIBI: a_offset = off_z * stride_az + off_h_q * stride_ah - alibi_slope = tl.load(alibi_slopes + a_offset) + alibi_slope = tl.load(ALIBI_SLOPES + a_offset) else: alibi_slope = None - if NEEDS_SDMASK: - sd_mask_offset = ( - sd_mask + off_z * stride_sz + off_h_q * stride_sh - ) # + cu_seqlens_q_start * stride_sm - sd_mask_ptrs = ( - sd_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn - ) - else: - sd_mask_ptrs = None - - if ENABLE_DROPOUT: - dropout_mask_offset = ( - dropout_mask + off_z * stride_sz + off_h_q * stride_sh - ) # + cu_seqlens_q_start * stride_sm - dropout_mask_ptrs = ( - dropout_mask_offset - + offs_m[:, None] * stride_sm - + offs_n[None, :] * stride_sn - ) - batch_philox_offset = ( - philox_offset_base + off_z * stride_sz + off_h_q * stride_sh - ) # + cu_seqlens_q_start * stride_sm - philox_ptrs = ( - batch_philox_offset - + offs_m[:, None] * stride_sm - + offs_n[None, :] * stride_sn - ) - else: - dropout_mask_ptrs = None - philox_ptrs = 0 - # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=ACCUMULATOR_TYPE) l_i = tl.full([BLOCK_M], 1.0, dtype=ACCUMULATOR_TYPE) @@ -1279,14 +1273,18 @@ def attn_fwd( stride_vk, stride_bn, stride_sn, + stride_sm, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, - philox_ptrs, - sd_mask_ptrs, - dropout_mask_ptrs, + philox_offset_base, + SD_MASK, + stride_sz, + stride_sh, + off_z, + off_h_q, offs_m, offs_n, offs_d_qk, @@ -1341,14 +1339,18 @@ def attn_fwd( stride_vk, stride_bn, stride_sn, + stride_sm, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, - philox_ptrs, - sd_mask_ptrs, - dropout_mask_ptrs, + philox_offset_base, + SD_MASK, + stride_sz, + stride_sh, + off_z, + off_h_q, offs_m, offs_n, offs_d_qk, @@ -1403,14 +1405,18 @@ def attn_fwd( stride_vk, stride_bn, stride_sn, + stride_sm, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, - philox_ptrs, - sd_mask_ptrs, - dropout_mask_ptrs, + philox_offset_base, + SD_MASK, + stride_sz, + stride_sh, + off_z, + off_h_q, offs_m, offs_n, offs_d_qk, @@ -1572,6 +1578,8 @@ def attention_forward_prefill_triton_impl( k: torch.Tensor, v: torch.Tensor, o: torch.Tensor, + softmax_lse: torch.Tensor, + sd_mask: Optional[torch.Tensor], sm_scale: float, alibi_slopes: Optional[torch.Tensor], causal: bool, @@ -1589,7 +1597,7 @@ def attention_forward_prefill_triton_impl( philox_seed: Optional[int], philox_offset: Optional[int], # misc - return_softmax: bool, + return_scores: bool, use_exp2: bool, # fp8 q_descale: Optional[torch.Tensor], @@ -1679,10 +1687,19 @@ def attention_forward_prefill_triton_impl( batch = len(cu_seqlens_q) - 1 head_size_qk = head_size_q - # softmax_lse shape - softmax_lse = torch.zeros( - (nheads_q, total_seqlen_q), device=q.device, dtype=torch.float32 - ) + # Assert softmax_lse tensor is large enough + assert ( + softmax_lse.shape[0] >= nheads_q + ), f"softmax_lse.shape[0]={softmax_lse.shape[0]} must be >= nheads_q={nheads_q}" + assert ( + softmax_lse.shape[1] >= total_seqlen_q + ), f"softmax_lse.shape[1]={softmax_lse.shape[1]} must be >= total_seqlen_q={total_seqlen_q}" + assert ( + softmax_lse.dtype == torch.float32 + ), f"softmax_lse must be float32, got {softmax_lse.dtype}" + assert ( + softmax_lse.device == q.device + ), f"softmax_lse must be on same device as q" # strides stride_qb, stride_qh, stride_qm, stride_qd = ( @@ -1755,10 +1772,22 @@ def attention_forward_prefill_triton_impl( max_seqlens_q = seqlen_q max_seqlens_k = seqlen_k - # softmax_lse shape - softmax_lse = torch.zeros( - (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 - ) + # Assert softmax_lse tensor is large enough + assert ( + softmax_lse.shape[0] >= batch + ), f"softmax_lse.shape[0]={softmax_lse.shape[0]} must be >= batch={batch}" + assert ( + softmax_lse.shape[1] >= nheads_q + ), f"softmax_lse.shape[1]={softmax_lse.shape[1]} must be >= nheads_q={nheads_q}" + assert ( + softmax_lse.shape[2] >= seqlen_q + ), f"softmax_lse.shape[2]={softmax_lse.shape[2]} must be >= seqlen_q={seqlen_q}" + assert ( + softmax_lse.dtype == torch.float32 + ), f"softmax_lse must be float32, got {softmax_lse.dtype}" + assert ( + softmax_lse.device == q.device + ), f"softmax_lse must be on same device as q" # strides stride_qb, stride_qh, stride_qm, stride_qd = ( @@ -1807,16 +1836,18 @@ def attention_forward_prefill_triton_impl( ) # fp8 setup and assertions - IS_FP8 = is_fp8(q) + IS_FP8 = is_fp8([q, k, v]) if IS_FP8: - # we already asserted that q, k, v all have the same dtype, so no need to check each one - FP8_MAX = torch.finfo(q.dtype).max + rec_dtype = get_recommended_fp8_dtype(q) + if q.dtype != rec_dtype or k.dtype != rec_dtype or v.dtype != rec_dtype: + arch = get_arch() + warnings.warn( + f"Use {rec_dtype} data type on {arch}. Got q: {q.dtype}, k: {k.dtype}, v: {v.dtype}", + UserWarning, + ) - # Check and create default descale tensors if not provided if (q_descale is None) or (k_descale is None) or (v_descale is None): - import warnings - warnings.warn( "FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning, @@ -1834,6 +1865,23 @@ def attention_forward_prefill_triton_impl( v_descale = torch.ones( batch, nheads_k, dtype=torch.float32, device=q.device ) + else: + # Enforce exact expected shapes; no reshaping or normalization. + assert ( + q_descale.dim() == 2 + and q_descale.shape[0] == batch + and q_descale.shape[1] == nheads_k + ), f"q_descale expected shape ({batch}, {nheads_k}) got {tuple(q_descale.shape)}" + assert ( + k_descale.dim() == 2 + and k_descale.shape[0] == batch + and k_descale.shape[1] == nheads_k + ), f"k_descale expected shape ({batch}, {nheads_k}) got {tuple(k_descale.shape)}" + assert ( + v_descale.dim() == 2 + and v_descale.shape[0] == batch + and v_descale.shape[1] == nheads_k + ), f"v_descale expected shape ({batch}, {nheads_k}) got {tuple(v_descale.shape)}" # o should be fp32 or fp16/bf16 assert o.dtype in [ @@ -1875,29 +1923,27 @@ def attention_forward_prefill_triton_impl( padded_d_model_qk = max(padded_d_model_qk, 16) padded_d_model_v = max(padded_d_model_v, 16) - # sd_mask is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out - # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according - # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing - # only. This return holds no useful output aside from debugging. - NEEDS_SDMASK = (dropout_p > 0.0) or return_softmax - if NEEDS_SDMASK: - sd_mask = torch.zeros( - (batch, nheads_q, max_seqlens_q, max_seqlens_k), - device=q.device, - dtype=torch.float32, - ) - if DROPOUT_USE_PYTORCH: - dropout_mask = create_dropout_mask( - dropout_p, - (batch, nheads_q, max_seqlens_q, max_seqlens_k), - seed=philox_seed, - ) - else: - dropout_mask = torch.zeros( - (batch, nheads_q, max_seqlens_q, max_seqlens_k), - device=q.device, - dtype=torch.float32, - ) + # sd_mask assertions and strides + if sd_mask is not None: + assert dropout_p > 0.0 or return_scores, "sd_mask provided but not used" + assert ( + sd_mask is not None + ), "sd_mask must be provided when return_scores=True or dropout_p > 0" + # Assert sd_mask tensor is large enough + assert ( + sd_mask.shape[0] >= batch + ), f"sd_mask.shape[0]={sd_mask.shape[0]} must be >= batch={batch}" + assert ( + sd_mask.shape[1] >= nheads_q + ), f"sd_mask.shape[1]={sd_mask.shape[1]} must be >= nheads_q={nheads_q}" + assert ( + sd_mask.shape[2] >= max_seqlens_q + ), f"sd_mask.shape[2]={sd_mask.shape[2]} must be >= max_seqlens_q={max_seqlens_q}" + assert ( + sd_mask.shape[3] >= max_seqlens_k + ), f"sd_mask.shape[3]={sd_mask.shape[3]} must be >= max_seqlens_k={max_seqlens_k}" + assert sd_mask.device == q.device, f"sd_mask must be on same device as q" + stride_sz, stride_sh, stride_sm, stride_sn = ( sd_mask.stride(0), sd_mask.stride(1), @@ -1905,8 +1951,6 @@ def attention_forward_prefill_triton_impl( sd_mask.stride(3), ) else: - sd_mask = None - dropout_mask = None stride_sz, stride_sh, stride_sm, stride_sn = (0, 0, 0, 0) if bias is not None: @@ -1932,9 +1976,10 @@ def attention_forward_prefill_triton_impl( stride_q_descale_z, stride_k_descale_z, stride_v_descale_z, - sm_scale, softmax_lse, o, + sd_mask, + alibi_slopes, stride_qb, stride_qh, stride_qm, @@ -1971,15 +2016,13 @@ def attention_forward_prefill_triton_impl( dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, - sd_mask=sd_mask, - dropout_mask=dropout_mask, - alibi_slopes=alibi_slopes, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL_QK=head_size_qk, ACTUAL_BLOCK_DMODEL_V=head_size_v, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k, + SM_SCALE=sm_scale, IS_CAUSAL=causal, USE_SLIDING_WINDOW=use_sliding_window, WINDOW_SIZE_LEFT=window_size_left, @@ -1991,12 +2034,9 @@ def attention_forward_prefill_triton_impl( USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p > 0.0, USE_EXP2=use_exp2, - RETURN_SCORES=return_softmax, - NEEDS_SDMASK=NEEDS_SDMASK, + RETURN_SCORES=return_scores, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_P_DESCALE=False, USE_SEQUSED=(seqused_q is not None or seqused_k is not None), - ) # Add flag for seqused - - return softmax_lse, sd_mask if return_softmax else None + ) diff --git a/flash_attn/flash_attn_triton_amd/interface_v2.py b/flash_attn/flash_attn_triton_amd/interface_v2.py index 134c4a76c12..d303ba63e7a 100644 --- a/flash_attn/flash_attn_triton_amd/interface_v2.py +++ b/flash_attn/flash_attn_triton_amd/interface_v2.py @@ -4,7 +4,15 @@ from .fwd_prefill import attention_forward_prefill_triton_impl from .fwd_decode import attention_forward_decode_triton_impl from .bwd import attention_backward_triton_impl -from .utils import DEBUG, USE_EXP2, BWD_MODE, PHILOX_SEED, PHILOX_OFFSET +from .utils import ( + DEBUG, + USE_EXP2, + BWD_MODE, + PHILOX_SEED, + PHILOX_OFFSET, + SHAPE_EXPECTATIONS, + round_multiple, +) def fwd( @@ -38,10 +46,10 @@ def fwd( if DEBUG: print() print("flash_attn_triton_amd.py::fwd inputs") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("out:", out, out.shape if out is not None else None) + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("out:", out.shape if out is not None else None) print("alibi_slopes:", alibi_slopes) print("dropout_p:", dropout_p) print("softmax_scale:", softmax_scale) @@ -50,7 +58,11 @@ def fwd( print("window_size_right:", window_size_right) print("softcap:", softcap) print("return_softmax:", return_softmax) - out = torch.zeros_like(q) if out is None else out.zero_() + + if out is None: + out = torch.zeros_like(q) + else: + out.zero_() # Layout / shapes layout = "bshd" @@ -77,14 +89,51 @@ def fwd( nheads_k = k.shape[2] assert (nheads_q % nheads_k) == 0 + # Create output tensors based on shape expectations + if SHAPE_EXPECTATIONS == "rounded": + softmax_lse = torch.zeros( + (batch, nheads_q, round_multiple(max_seqlen_q, 128)), + device=q.device, + dtype=torch.float32, + ) + if dropout_p > 0.0 or return_softmax: + sd_mask = torch.zeros( + ( + batch, + nheads_q, + round_multiple(max_seqlen_q, 128), + round_multiple(max_seqlen_k, 128), + ), + device=q.device, + dtype=torch.float32, + ) + else: + sd_mask = None + else: + softmax_lse = torch.zeros( + (batch, nheads_q, max_seqlen_q), + device=q.device, + dtype=torch.float32, + ) + if dropout_p > 0.0 or return_softmax: + sd_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=torch.float32, + ) + else: + sd_mask = None + # call implementation if DEBUG: print("Using Triton implementation") - softmax_lse, sd_mask = attention_forward_prefill_triton_impl( + attention_forward_prefill_triton_impl( q, k, v, out, + softmax_lse, + sd_mask, softmax_scale, alibi_slopes, causal, @@ -104,35 +153,53 @@ def fwd( None, None, None, + None, + None, + None, + None, ) if DEBUG: print("flash_attn_triton_amd.py::fwd outputs") - print("o:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None) + print("o:", out.shape if out is not None else None) + print("softmax_lse:", softmax_lse.shape if softmax_lse is not None else None) + print("sd_mask:", sd_mask.shape if sd_mask is not None else None) print("rng_state:", rng_state) # --- Assertions (shape + dtype contracts) --- # out: (B, Sq, Hq, D) assert out.shape == q.shape, f"[fwd] out shape {out.shape} != q shape {q.shape}" - # softmax_lse: (B, Hq, Sq) - expected_lse_shape = (q.shape[0], q.shape[2], q.shape[1]) - assert ( - softmax_lse.shape == expected_lse_shape - ), f"[fwd] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" + # softmax_lse dtype assert ( softmax_lse.dtype == torch.float32 ), f"[fwd] softmax_lse dtype {softmax_lse.dtype} != torch.float32" + # softmax_lse shape depends on SHAPE_EXPECTATIONS + if SHAPE_EXPECTATIONS == "rounded": + expected_lse_shape = (q.shape[0], q.shape[2], round_multiple(q.shape[1], 128)) + else: + expected_lse_shape = (q.shape[0], q.shape[2], q.shape[1]) + assert ( + softmax_lse.shape == expected_lse_shape + ), f"[fwd] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" if return_softmax: # sd_mask: (B, Hq, Sq, Sk) assert sd_mask is not None, "[fwd] return_softmax=True but sd_mask is None" assert sd_mask.dim() == 4, f"[fwd] sd_mask dim {sd_mask.dim()} != 4" - assert ( - sd_mask.shape[0] == q.shape[0] - and sd_mask.shape[1] == q.shape[2] - and sd_mask.shape[2] == q.shape[1] - ), f"[fwd] sd_mask leading dims {sd_mask.shape[:3]} mismatch (B,Hq,Sq) {(q.shape[0], q.shape[2], q.shape[1])}" + if SHAPE_EXPECTATIONS == "rounded": + expected_sq = round_multiple(q.shape[1], 128) + expected_sk = round_multiple(k.shape[1], 128) + assert ( + sd_mask.shape[0] == q.shape[0] + and sd_mask.shape[1] == q.shape[2] + and sd_mask.shape[2] == expected_sq + and sd_mask.shape[3] == expected_sk + ), f"[fwd] sd_mask shape {sd_mask.shape} != (B={q.shape[0]}, Hq={q.shape[2]}, Sq={expected_sq}, Sk={expected_sk})" + else: + assert ( + sd_mask.shape[0] == q.shape[0] + and sd_mask.shape[1] == q.shape[2] + and sd_mask.shape[2] == q.shape[1] + ), f"[fwd] sd_mask leading dims {sd_mask.shape[:3]} mismatch (B,Hq,Sq) {(q.shape[0], q.shape[2], q.shape[1])}" else: assert sd_mask is None, "[fwd] return_softmax=False but sd_mask is not None" @@ -169,14 +236,14 @@ def bwd( print() print("flash_attn_triton_amd.py::bwd inputs") print("dout:", dout, dout.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("out:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape if dq is not None else None) - print("dk:", dk, dk.shape if dk is not None else None) - print("dv:", dv, dv.shape if dv is not None else None) + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("out:", out.shape) + print("softmax_lse:", softmax_lse.shape) + print("dq:", dq.shape if dq is not None else None) + print("dk:", dk.shape if dk is not None else None) + print("dv:", dv.shape if dv is not None else None) print("alibi_slopes:", alibi_slopes) print("dropout_p:", dropout_p) print("out:", out) @@ -193,7 +260,20 @@ def bwd( dv = torch.zeros_like(v) if dv is None else dv.zero_() # get shape - batch, _, nheads_q, _ = q.shape + batch, seqlen_q, nheads_q, _ = q.shape + + # Create delta tensor with shape based on expectations + # delta (softmax_d) : (B, Hq, Sq) or (B, Hq, round_multiple(Sq, 128)) + if SHAPE_EXPECTATIONS == "rounded": + delta = torch.zeros( + (batch, nheads_q, round_multiple(seqlen_q, 128)), + device=q.device, + dtype=torch.float32, + ) + else: + delta = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) # Upstream change: base seeding logic on provided rng_state instead of dropout probability. if rng_state is not None: @@ -211,8 +291,8 @@ def bwd( # call implementation if DEBUG: - print("Using Triton implementation") - delta = attention_backward_triton_impl( + print(f"Using Triton implementation in {BWD_MODE} mode") + attention_backward_triton_impl( do=dout, q=q, k=k, @@ -222,13 +302,14 @@ def bwd( dq=dq, dk=dk, dv=dv, + delta=delta, sm_scale=softmax_scale, alibi_slopes=alibi_slopes, causal=causal, layout="bshd", cu_seqlens_q=None, cu_seqlens_k=None, - max_seqlen_q=q.shape[1], + max_seqlen_q=seqlen_q, max_seqlen_k=k.shape[1], seqused_q=None, seqused_k=None, @@ -249,7 +330,10 @@ def bwd( assert dk.shape == k.shape, f"[bwd] dk shape {dk.shape} != k shape {k.shape}" assert dv.shape == v.shape, f"[bwd] dv shape {dv.shape} != v shape {v.shape}" # delta (softmax_d) : (B, Hq, Sq) - expected_delta_shape = (q.shape[0], q.shape[2], q.shape[1]) + if SHAPE_EXPECTATIONS == "rounded": + expected_delta_shape = (q.shape[0], q.shape[2], round_multiple(q.shape[1], 128)) + else: + expected_delta_shape = (q.shape[0], q.shape[2], q.shape[1]) assert ( delta.shape == expected_delta_shape ), f"[bwd] delta shape {delta.shape} != {expected_delta_shape}" @@ -305,9 +389,9 @@ def varlen_fwd( if DEBUG: print() print("flash_attn_triton_amd.py::varlen_fwd") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) print("alibi_slopes:", alibi_slopes) @@ -324,7 +408,33 @@ def varlen_fwd( # Layout and basic info for varlen layout = "thd" batch = len(cu_seqlens_q) - 1 - _, nheads_q, _ = q.shape + total_q, nheads_q, _ = q.shape + + # Create softmax_lse tensor - varlen always uses exact shape (Hq, Total_Q) + softmax_lse = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) + + # Create sd_mask tensor if needed + if return_softmax: + # sd_mask: (B, Hq, Sq, Sk) - shape based on expectations + if SHAPE_EXPECTATIONS == "rounded": + sd_mask = torch.zeros( + ( + batch, + nheads_q, + round_multiple(max_seqlen_q, 128), + round_multiple(max_seqlen_k, 128), + ), + device=q.device, + dtype=q.dtype, + ) + else: + sd_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=q.dtype, + ) + else: + sd_mask = None if alibi_slopes is not None: if alibi_slopes.dim() == 1: @@ -346,11 +456,13 @@ def varlen_fwd( # call implementation if DEBUG: print("Using Triton implementation") - softmax_lse, sd_mask = attention_forward_prefill_triton_impl( + attention_forward_prefill_triton_impl( q, k, v, out, + softmax_lse, + sd_mask, softmax_scale, alibi_slopes, causal, @@ -396,12 +508,23 @@ def varlen_fwd( sd_mask is not None ), "[varlen_fwd] return_softmax=True but sd_mask is None" assert sd_mask.dim() == 4, f"[varlen_fwd] sd_mask dim {sd_mask.dim()} != 4" - assert sd_mask.shape[0] == ( - len(cu_seqlens_q) - 1 - ), f"[varlen_fwd] sd_mask batch {sd_mask.shape[0]} != {len(cu_seqlens_q)-1}" + batch = len(cu_seqlens_q) - 1 + assert ( + sd_mask.shape[0] == batch + ), f"[varlen_fwd] sd_mask batch {sd_mask.shape[0]} != {batch}" assert ( sd_mask.shape[1] == q.shape[1] ), f"[varlen_fwd] sd_mask nheads {sd_mask.shape[1]} != {q.shape[1]}" + if SHAPE_EXPECTATIONS == "rounded": + expected_sq = round_multiple(max_seqlen_q, 128) + expected_sk = round_multiple(max_seqlen_k, 128) + assert ( + sd_mask.shape[2] == expected_sq and sd_mask.shape[3] == expected_sk + ), f"[varlen_fwd] sd_mask shape {sd_mask.shape} != (B={batch}, Hq={q.shape[1]}, Sq={expected_sq}, Sk={expected_sk})" + else: + assert ( + sd_mask.shape[2] == max_seqlen_q and sd_mask.shape[3] == max_seqlen_k + ), f"[varlen_fwd] sd_mask shape {sd_mask.shape} != (B={batch}, Hq={q.shape[1]}, Sq={max_seqlen_q}, Sk={max_seqlen_k})" else: assert ( sd_mask is None @@ -447,15 +570,15 @@ def varlen_bwd( if DEBUG: print() print("varlen_bwd") - print("dout:", dout, dout.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) + print("dout:", dout.shape) + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) print("out:", out) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape if dq is not None else None) - print("dk:", dk, dk.shape if dk is not None else None) - print("dv:", dv, dv.shape if dv is not None else None) + print("softmax_lse:", softmax_lse.shape) + print("dq:", dq.shape if dq is not None else None) + print("dk:", dk.shape if dk is not None else None) + print("dv:", dv.shape if dv is not None else None) print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) print("alibi_slopes:", alibi_slopes) @@ -476,7 +599,16 @@ def varlen_bwd( # get shape batch = len(cu_seqlens_q) - 1 - _, nheads_q, _ = q.shape + total_q, nheads_q, _ = q.shape + + # Create delta tensor with shape based on expectations + # delta (softmax_d) : (Hq, Total_Q) or (Hq, Total_Q + 128*batch) + if SHAPE_EXPECTATIONS == "rounded": + delta = torch.zeros( + (nheads_q, total_q + 128 * batch), device=q.device, dtype=torch.float32 + ) + else: + delta = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) # Upstream change: base seeding logic on provided rng_state instead of dropout probability. if rng_state is not None: @@ -494,8 +626,8 @@ def varlen_bwd( # call implementation if DEBUG: - print("Using Triton implementation") - delta = attention_backward_triton_impl( + print(f"Using Triton implementation in {BWD_MODE} mode") + attention_backward_triton_impl( do=dout, q=q, k=k, @@ -505,6 +637,7 @@ def varlen_bwd( dq=dq, dk=dk, dv=dv, + delta=delta, sm_scale=softmax_scale, alibi_slopes=alibi_slopes, causal=causal, @@ -532,7 +665,11 @@ def varlen_bwd( assert dq.shape == q.shape, f"[varlen_bwd] dq shape {dq.shape} != q shape {q.shape}" assert dk.shape == k.shape, f"[varlen_bwd] dk shape {dk.shape} != k shape {k.shape}" assert dv.shape == v.shape, f"[varlen_bwd] dv shape {dv.shape} != v shape {v.shape}" - expected_delta_shape = (q.shape[1], q.shape[0]) # (Hq, Total_Q) + if SHAPE_EXPECTATIONS == "rounded": + batch = len(cu_seqlens_q) - 1 + expected_delta_shape = (q.shape[1], q.shape[0] + 128 * batch) + else: + expected_delta_shape = (q.shape[1], q.shape[0]) # (Hq, Total_Q) assert ( delta.shape == expected_delta_shape ), f"[varlen_bwd] delta shape {delta.shape} != {expected_delta_shape}" @@ -622,7 +759,12 @@ def fwd_kvcache( v_new = v # get shape - batch, _, nheads_q, _ = q.shape + batch, seqlen_q, nheads_q, _ = q.shape + + # Create softmax_lse tensor - decode always uses exact shape (B, Hq, Sq) + softmax_lse = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) if alibi_slopes is not None: if alibi_slopes.dim() == 1: @@ -633,13 +775,14 @@ def fwd_kvcache( # launch kernel if DEBUG: print("Using Triton implementation") - softmax_lse = attention_forward_decode_triton_impl( + attention_forward_decode_triton_impl( q, k_cache, v_cache, k_new, v_new, out, + softmax_lse, softmax_scale, causal, window_left, diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py index 436077a8a7c..3ed35de5cd1 100755 --- a/flash_attn/flash_attn_triton_amd/interface_v3.py +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -1,10 +1,19 @@ -import torch import os +import warnings +import torch from typing import Optional, Union, Tuple from .fwd_prefill import attention_forward_prefill_triton_impl from .fwd_decode import attention_forward_decode_triton_impl from .bwd import attention_backward_triton_impl -from .utils import DEBUG, USE_EXP2, BWD_MODE, PHILOX_SEED, PHILOX_OFFSET, is_fp8 +from .utils import ( + DEBUG, + USE_EXP2, + BWD_MODE, + PHILOX_SEED, + PHILOX_OFFSET, + is_fp8, + get_recommended_fp8_dtype, +) def fwd( @@ -52,13 +61,29 @@ def fwd( if DEBUG: print() print("interface_fa_v3.py::fwd inputs") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("k_new:", k_new, k_new.shape if k_new is not None else None) - print("v_new:", v_new, v_new.shape if v_new is not None else None) - print("qv:", qv, qv.shape if qv is not None else None) - print("out:", out, out.shape if out is not None else None) + print("q:", q.dtype if q is not None else None, q.shape) + print("k:", k.dtype if k is not None else None, k.shape) + print("v:", v.dtype if v is not None else None, v.shape) + print( + "k_new:", + k_new.dtype if k_new is not None else None, + k_new.shape if k_new is not None else None, + ) + print( + "v_new:", + v_new.dtype if v_new is not None else None, + v_new.shape if v_new is not None else None, + ) + print( + "qv:", + qv.dtype if qv is not None else None, + qv.shape if qv is not None else None, + ) + print( + "out:", + out.dtype if out is not None else None, + out.shape if out is not None else None, + ) print( "cu_seqlens_q:", cu_seqlens_q, @@ -111,13 +136,19 @@ def fwd( seqlens_rotary.shape if seqlens_rotary is not None else None, ) print( - "q_descale:", q_descale, q_descale.shape if q_descale is not None else None + "q_descale:", + q_descale.dtype if q_descale is not None else None, + q_descale.shape if q_descale is not None else None, ) print( - "k_descale:", k_descale, k_descale.shape if k_descale is not None else None + "k_descale:", + k_descale.dtype if k_descale is not None else None, + k_descale.shape if k_descale is not None else None, ) print( - "v_descale:", v_descale, v_descale.shape if v_descale is not None else None + "v_descale:", + v_descale.dtype if v_descale is not None else None, + v_descale.shape if v_descale is not None else None, ) print("softmax_scale:", softmax_scale) print("causal:", causal) @@ -185,9 +216,6 @@ def fwd( "cu_seqlens_k_new is not yet supported in the AMD Triton backend" ) - # if seqlens_rotary is not None: - # raise NotImplementedError("seqlens_rotary is not yet supported in the AMD Triton backend") - # establish layout / varlen & max seq lens if cu_seqlens_q is not None: if len(q.shape) != 3: @@ -241,7 +269,8 @@ def fwd( ) if out is None: - out_dtype = torch.float32 if is_fp8(q) else q.dtype + # NOTE: Using types that are lower precision than float32 such as bfloat16 for fp8 causes mismatches on a small set of tests. + out_dtype = torch.float32 if is_fp8([q, k, v]) else q.dtype if layout == "bshd": out = torch.zeros( q.shape[0], @@ -262,45 +291,6 @@ def fwd( else: out = out.zero_() - if is_fp8(q): - if (q_descale is None) or (k_descale is None) or (v_descale is None): - import warnings - - warnings.warn( - "FP8 tensors detected but descale factors not provided. Using default scale of 1.0", - UserWarning, - ) - else: - # Enforce exact expected shapes; no reshaping or normalization. - if layout == "bshd": - expected_batch = q.shape[0] - expected_q_heads = q.shape[2] - expected_kv_heads = k.shape[2] - else: # thd layout - expected_batch = ( - (len(cu_seqlens_q_local) - 1) - if cu_seqlens_q_local is not None - else 1 - ) - expected_q_heads = q.shape[1] - expected_kv_heads = k.shape[1] - - assert ( - q_descale.dim() == 2 - and q_descale.shape[0] == expected_batch - and q_descale.shape[1] == expected_kv_heads - ), f"q_descale expected shape ({expected_batch}, {expected_kv_heads}) got {tuple(q_descale.shape)}" - assert ( - k_descale.dim() == 2 - and k_descale.shape[0] == expected_batch - and k_descale.shape[1] == expected_kv_heads - ), f"k_descale expected shape ({expected_batch}, {expected_kv_heads}) got {tuple(k_descale.shape)}" - assert ( - v_descale.dim() == 2 - and v_descale.shape[0] == expected_batch - and v_descale.shape[1] == expected_kv_heads - ), f"v_descale expected shape ({expected_batch}, {expected_kv_heads}) got {tuple(v_descale.shape)}" - # Handle causal mask causal_flag = bool(causal) @@ -323,13 +313,20 @@ def fwd( f"Using Decode Triton implementation (cache_seqlens={seqused_k is not None}, k_new={k_new is not None}, v_new={v_new is not None}, kv_batch_idx={kv_batch_idx is not None})" ) - softmax_lse = attention_forward_decode_triton_impl( + # Create softmax_lse tensor for decode - always exact shape (B, Hq, Sq) + batch, seqlen_q, nheads_q, _ = q.shape + softmax_lse = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) + + attention_forward_decode_triton_impl( q, k, v, k_new, v_new, out, + softmax_lse, softmax_scale, causal_flag, window_size_left, @@ -350,11 +347,31 @@ def fwd( else: if DEBUG: print("Using Prefill Triton implementation") - softmax_lse, _ = attention_forward_prefill_triton_impl( + + # Create softmax_lse tensor - FA3 always uses exact shapes + if layout == "thd": + # varlen: (Hq, Total_Q) + total_q, nheads_q, _ = q.shape + softmax_lse = torch.zeros( + (nheads_q, total_q), device=q.device, dtype=torch.float32 + ) + else: + # bshd: (B, Hq, Sq) + batch, seqlen_q, nheads_q, _ = q.shape + softmax_lse = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) + + # sd_mask is not returned in v3 interface + sd_mask = None + + attention_forward_prefill_triton_impl( q, k, v, out, + softmax_lse, + sd_mask, softmax_scale, alibi_slopes, causal_flag, @@ -384,8 +401,59 @@ def fwd( if DEBUG: print("interface_fa_v3.py::fwd outputs") - print("out:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) + print( + "out:", + out.dtype if out is not None else None, + out.shape if out is not None else None, + ) + print( + "softmax_lse:", + softmax_lse.dtype if softmax_lse is not None else None, + softmax_lse.shape if softmax_lse is not None else None, + ) + + # --- Assertions (FA3 always expects exact shapes) --- + # out: same shape as q except last dim is v's head_dim + if layout == "thd": + # varlen: (Total_Q, Hq, Dv) + assert ( + out.shape[0] == q.shape[0] + ), f"[fwd_v3] out.shape[0] {out.shape[0]} != q.shape[0] {q.shape[0]}" + assert ( + out.shape[1] == q.shape[1] + ), f"[fwd_v3] out.shape[1] {out.shape[1]} != q.shape[1] {q.shape[1]}" + assert ( + out.shape[2] == v.shape[-1] + ), f"[fwd_v3] out.shape[2] {out.shape[2]} != v.shape[-1] {v.shape[-1]}" + else: + # bshd: (B, Sq, Hq, Dv) + assert ( + out.shape[0] == q.shape[0] + ), f"[fwd_v3] out.shape[0] {out.shape[0]} != q.shape[0] {q.shape[0]}" + assert ( + out.shape[1] == q.shape[1] + ), f"[fwd_v3] out.shape[1] {out.shape[1]} != q.shape[1] {q.shape[1]}" + assert ( + out.shape[2] == q.shape[2] + ), f"[fwd_v3] out.shape[2] {out.shape[2]} != q.shape[2] {q.shape[2]}" + assert ( + out.shape[3] == v.shape[-1] + ), f"[fwd_v3] out.shape[3] {out.shape[3]} != v.shape[-1] {v.shape[-1]}" + + # softmax_lse dtype + assert ( + softmax_lse.dtype == torch.float32 + ), f"[fwd_v3] softmax_lse dtype {softmax_lse.dtype} != torch.float32" + # softmax_lse shape depends on layout + if layout == "thd": + # varlen: (Hq, Total_Q) + expected_lse_shape = (q.shape[1], q.shape[0]) + else: + # bshd: (B, Hq, Sq) + expected_lse_shape = (q.shape[0], q.shape[2], q.shape[1]) + assert ( + softmax_lse.shape == expected_lse_shape + ), f"[fwd_v3] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" # Return format compatible with v3 # V3 returns (out, softmax_lse, *rest) where rest can be empty or contain additional outputs @@ -425,15 +493,45 @@ def bwd( if DEBUG: print() print("interface_fa_v3.py::bwd inputs") - print("dout:", dout, dout.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("out:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape if dq is not None else None) - print("dk:", dk, dk.shape if dk is not None else None) - print("dv:", dv, dv.shape if dv is not None else None) + print( + "dout:", + dout.dtype if dout is not None else None, + dout.shape if dout is not None else None, + ) + print( + "q:", q.dtype if q is not None else None, q.shape if q is not None else None + ) + print( + "k:", k.dtype if k is not None else None, k.shape if k is not None else None + ) + print( + "v:", v.dtype if v is not None else None, v.shape if v is not None else None + ) + print( + "out:", + out.dtype if out is not None else None, + out.shape if out is not None else None, + ) + print( + "softmax_lse:", + softmax_lse.dtype if softmax_lse is not None else None, + softmax_lse.shape if softmax_lse is not None else None, + ) + print( + "dq:", + dq.dtype if dq is not None else None, + dq.shape if dq is not None else None, + ) + print( + "dk:", + dk.dtype if dk is not None else None, + dk.shape if dk is not None else None, + ) + print( + "dv:", + dv.dtype if dv is not None else None, + dv.shape if dv is not None else None, + ) print( "cu_seqlens_q:", cu_seqlens_q, @@ -475,22 +573,30 @@ def bwd( ) # Initialize gradient tensors if not provided - dq = torch.zeros_like(q) if dq is None else dq.zero_() - dk = torch.zeros_like(k) if dk is None else dk.zero_() - dv = torch.zeros_like(v) if dv is None else dv.zero_() + # NOTE: Using types that are lower precision than float32 such as bfloat16 for fp8 causes mismatches on a small set of tests. + grad_dtype = torch.float32 if is_fp8([q, k, v]) else q.dtype + dq = torch.zeros_like(q, dtype=grad_dtype) if dq is None else dq.zero_() + dk = torch.zeros_like(k, dtype=grad_dtype) if dk is None else dk.zero_() + dv = torch.zeros_like(v, dtype=grad_dtype) if dv is None else dv.zero_() # Determine layout based on cu_seqlens if cu_seqlens_q is not None and cu_seqlens_k is not None: # Variable length sequence mode layout = "thd" batch = len(cu_seqlens_q) - 1 - _, nheads_q, _ = q.shape + total_q, nheads_q, _ = q.shape + # Create delta tensor - varlen: (Hq, Total_Q) + delta = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) else: # Regular batch mode layout = "bshd" - batch, _, nheads_q, _ = q.shape + batch, seqlen_q, nheads_q, _ = q.shape max_seqlen_q = q.shape[1] if max_seqlen_q is None else max_seqlen_q max_seqlen_k = k.shape[1] if max_seqlen_k is None else max_seqlen_k + # Create delta tensor - bshd: (B, Hq, Sq) + delta = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) # V3 backward doesn't have dropout or alibi slopes dropout_p = 0.0 @@ -499,8 +605,8 @@ def bwd( # Call implementation if DEBUG: - print("Using Triton implementation (unified backward dispatcher)") - delta = attention_backward_triton_impl( + print(f"Using Triton implementation in {BWD_MODE} mode") + attention_backward_triton_impl( do=dout, q=q, k=k, @@ -510,6 +616,7 @@ def bwd( dq=dq, dk=dk, dv=dv, + delta=delta, sm_scale=softmax_scale, alibi_slopes=alibi_slopes, causal=causal, @@ -529,10 +636,45 @@ def bwd( if DEBUG: print("interface_fa_v3.py::bwd outputs") - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) - print("dv:", dv, dv.shape) - print("delta:", delta, delta.shape if delta is not None else None) + print( + "dq:", + dq.dtype if dq is not None else None, + dq.shape if dq is not None else None, + ) + print( + "dk:", + dk.dtype if dk is not None else None, + dk.shape if dk is not None else None, + ) + print( + "dv:", + dv.dtype if dv is not None else None, + dv.shape if dv is not None else None, + ) + print( + "delta:", + delta.dtype if delta is not None else None, + delta.shape if delta is not None else None, + ) + + # --- Assertions (FA3 always expects exact shapes) --- + # Gradients should match input shapes + assert dq.shape == q.shape, f"[bwd_v3] dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"[bwd_v3] dk shape {dk.shape} != k shape {k.shape}" + assert dv.shape == v.shape, f"[bwd_v3] dv shape {dv.shape} != v shape {v.shape}" + # delta (softmax_d) should match softmax_lse shape + assert ( + delta.dtype == torch.float32 + ), f"[bwd_v3] delta dtype {delta.dtype} != torch.float32" + if layout == "thd": + # varlen: (Hq, Total_Q) + expected_delta_shape = (q.shape[1], q.shape[0]) + else: + # bshd: (B, Hq, Sq) + expected_delta_shape = (q.shape[0], q.shape[2], q.shape[1]) + assert ( + delta.shape == expected_delta_shape + ), f"[bwd_v3] delta shape {delta.shape} != {expected_delta_shape}" # V3 expects (dq, dk, dv, softmax_d, *rest) # delta is the softmax_d in this case diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 71ed1c1c2de..d0a20eb6aa4 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -17,23 +17,13 @@ "true", "yes", ) -if AUTOTUNE: - os.environ["TRITON_PRINT_AUTOTUNING"] = "1" DEBUG = os.environ.get("FLASH_ATTENTION_TRITON_AMD_DEBUG", "0").lower() in ( "1", "true", "yes", ) -PERF = os.environ.get("FLASH_ATTENTION_TRITON_AMD_PERF", "0").lower() in ( - "1", - "true", - "yes", -) -USE_SINGLE_BWD_KERNEL = os.environ.get("USE_SINGLE_BWD_KERNEL", "0").lower() in ( - "1", - "true", - "yes", -) +if AUTOTUNE or DEBUG: + os.environ["TRITON_PRINT_AUTOTUNING"] = "1" USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" USE_TRITON_INTERPRET = os.environ.get("TRITON_INTERPRET", "0").lower() in ( "1", @@ -50,12 +40,11 @@ ) if USE_TRITON_ROCM: # TODO remove this random.seed(42) -BWD_MODE = os.environ.get("BWD_MODE", "fused_no_atomics").lower() -DROPOUT_USE_PYTORCH = False -DROPOUT_DUMP = False +BWD_MODE: Literal["fused", "fused_atomic", "split"] = "fused" USE_EXP2 = True PHILOX_SEED = 0x1BF58 PHILOX_OFFSET = 0x1D4B49 +SHAPE_EXPECTATIONS: Literal["exact", "rounded"] = "exact" # ------------------------------- @@ -555,299 +544,6 @@ def generate_varlen_kv_packed( return x, cu_seqlens, max_seqlen -def input_helper( - BATCH: int, - HQ: int, - HK: int, - N_CTX_Q: int, - N_CTX_K: int, - D_HEAD: int, - CAUSAL: bool, - DROPOUT_P: float, - dtype: torch.dtype, - layout: Literal["bshd", "bhsd", "thd"], - packing: Optional[Literal["kv", "qkv"]] = None, - device: Literal["cpu", "cuda"] = "cuda", -): - torch.manual_seed(20) - is_fp8_dtype = is_dtype_fp8(dtype) - - if layout == "thd": - # set params - TOTAL_SEQLENS_Q = BATCH * N_CTX_Q - TOTAL_SEQLENS_K = BATCH * N_CTX_K - equal_seqlens = False - - # deal with packing - if packing is None: - # gen tensors - if is_fp8_dtype: - q, cu_seqlens_q, max_seqlen_q, descale_q = generate_varlen_tensor( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - k, cu_seqlens_k, max_seqlen_k, descale_k = generate_varlen_tensor( - TOTAL_SEQLENS_K, - HK, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - v, _, _, descale_v = generate_varlen_tensor( - TOTAL_SEQLENS_K, - HK, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - do, _, _, descale_do = generate_varlen_tensor( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - else: - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor( - TOTAL_SEQLENS_K, - HK, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - v, _, _ = generate_varlen_tensor( - TOTAL_SEQLENS_K, - HK, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - do, _, _ = generate_varlen_tensor( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - elif packing == "kv": - # gen tensors with kv packing - if is_fp8_dtype: - raise ValueError("FP8 not supported for KV packing yet") - else: - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - kv, cu_seqlens_k, max_seqlen_k = generate_varlen_kv_packed( - TOTAL_SEQLENS_K, - HK, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - do, _, _ = generate_varlen_tensor( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - elif packing == "qkv": - # qkv packing - requires same sequence length for q and k - assert ( - N_CTX_Q == N_CTX_K - ), "For QKV packing, Q and K must have same sequence length" - assert HQ == HK, "For QKV packing, Q and K must have same number of heads" - - if is_fp8_dtype: - raise ValueError("FP8 not supported for QKV packing yet") - else: - qkv, cu_seqlens_q, max_seqlen_q = generate_varlen_qkv_packed( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - cu_seqlens_k = cu_seqlens_q - max_seqlen_k = max_seqlen_q - do, _, _ = generate_varlen_tensor( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - - elif layout == "bshd" or layout == "bhsd": - # deal with packing - if packing is None: - # gen tensors - if layout == "bshd": - if is_fp8_dtype: - q, descale_q = generate_bshd_tensor( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - k, descale_k = generate_bshd_tensor( - BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device - ) - v, descale_v = generate_bshd_tensor( - BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device - ) - do, descale_do = generate_bshd_tensor( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - else: - q = generate_bshd_tensor( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - k = generate_bshd_tensor( - BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device - ) - v = generate_bshd_tensor( - BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device - ) - do = generate_bshd_tensor( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - elif layout == "bhsd": - q, descale_q = generate_bhsd_tensor( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - k, descale_k = generate_bhsd_tensor( - BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device - ) - v, descale_v = generate_bhsd_tensor( - BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device - ) - do, descale_do = generate_bhsd_tensor( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - else: - q = generate_bhsd_tensor( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - k = generate_bhsd_tensor( - BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device - ) - v = generate_bhsd_tensor( - BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device - ) - do = generate_bhsd_tensor( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - elif packing == "kv": - # gen tensors with kv packing - if is_fp8_dtype: - raise ValueError("FP8 not supported for KV packing yet") - else: - if layout == "bshd": - q = generate_bshd_tensor( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - kv = generate_bshd_kv_packed( - BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device - ) - do = generate_bshd_tensor( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - elif layout == "bhsd": - q = generate_bhsd_tensor( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - kv = generate_bhsd_kv_packed( - BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device - ) - do = generate_bhsd_tensor( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - elif packing == "qkv": - # qkv packing - requires same sequence length for q and k - assert ( - N_CTX_Q == N_CTX_K - ), "For QKV packing, Q and K must have same sequence length" - assert HQ == HK, "For QKV packing, Q and K must have same number of heads" - - if is_fp8_dtype: - raise ValueError("FP8 not supported for QKV packing yet") - else: - if layout == "bshd": - qkv = generate_bshd_qkv_packed( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - do = generate_bshd_tensor( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - elif layout == "bhsd": - qkv = generate_bhsd_qkv_packed( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - do = generate_bhsd_tensor( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - - else: - raise ValueError(f"Unknown layout: {layout}") - - # return based on packing - if packing is None: - if is_fp8_dtype: - return (q, descale_q), (k, descale_k), (v, descale_v), (do, descale_do) - else: - return q, k, v, do - elif packing == "kv": - if is_fp8_dtype: - raise ValueError("FP8 not supported kv packing yet") - else: - return q, kv, do - elif packing == "qkv": - if is_fp8_dtype: - raise ValueError("FP8 not supported qkv packing yet") - else: - return qkv, do - else: - assert False, f"Unsupported packing mode: {packing}" - - # ------------------------------- # Alibi # ------------------------------- @@ -889,23 +585,72 @@ def compute_alibi_block( # ------------------------------- # FP8 # ------------------------------- -def is_dtype_fp8(dtype): - if dtype in { +def is_dtype_fp8(dtype) -> bool: + supported = { torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz, - }: - if arch_supports_fp8(): - return True - else: - raise RuntimeError("This device doesnot support fp8") - else: + } + if dtype not in supported: return False + return True + + +_RECOMMENDED_FP8_REPLACEMENTS = { + "gfx942": { + torch.float8_e4m3fn: torch.float8_e4m3fnuz, + torch.float8_e5m2: torch.float8_e5m2fnuz, + }, +} + + +def get_recommended_fp8_dtype(x): + dtype = x.dtype if isinstance(x, torch.Tensor) else x + if not is_dtype_fp8(dtype): + return dtype + arch = get_arch() + return _RECOMMENDED_FP8_REPLACEMENTS.get(arch, {}).get(dtype, dtype) -def is_fp8(x): - return is_dtype_fp8(x.dtype) +def is_fp8(x) -> bool: + """Return whether tensor(s) use FP8. + + Accepts either a single tensor or a list/tuple of tensors. + + Rules: + * Single tensor: return True if FP8 (after arch validation), else False. + * Multiple tensors: + - If all tensors are FP8 -> return True. + - If none are FP8 -> return False. + - If a mix of FP8 and non-FP8 -> raise ValueError. + + Empty list/tuple returns False. + """ + + def _is_fp8_single(t: torch.Tensor) -> bool: + if is_dtype_fp8(t.dtype): + arch = get_arch() + if arch not in ("gfx942", "gfx950"): + raise RuntimeError( + f"{arch} is not in the list of supported architectures for FP8" + ) + return True + return False + + if isinstance(x, (list, tuple)): + if len(x) == 0: + return False + flags = [_is_fp8_single(t) for t in x] + if all(flags): + return True + if not any(flags): + return False + raise ValueError( + "Mixed FP8 and non-FP8 tensors provided; either all or none must be FP8." + ) + else: + return _is_fp8_single(x) @triton.jit @@ -1514,9 +1259,7 @@ def _apply_rotary_kernel( batch, ) - # NOTE: We assume CUDA device indexing compatibility in upstream; adapt for ROCm by using device context. - # For ROCm, torch.cuda.device works if HIP_VISIBLE_DEVICES mapping is set. - with torch.cuda.device(x.device.index): # Works for ROCm as alias + with torch.cuda.device(x.device.index): torch.library.wrap_triton(_rotary_kernel)[grid]( out, x, @@ -1737,6 +1480,13 @@ def get_arch(): return triton.runtime.driver.active.get_current_target().arch +@functools.cache +def get_cu_count(): + return torch.cuda.get_device_properties( + torch.cuda.current_device() + ).multi_processor_count + + @functools.cache def is_cdna(): return is_hip() and get_arch() in ( @@ -1759,8 +1509,3 @@ def is_rdna(): "gfx1200", "gfx1201", ) - - -@functools.cache -def arch_supports_fp8(): - return is_hip() and get_arch() in ("gfx942") diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 4f54e44b95c..36d928f30f3 100755 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -3,15 +3,18 @@ from typing import Optional, Union, List, Tuple import os +import sys +from pathlib import Path import torch import torch.nn as nn USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" if USE_TRITON_ROCM: - import sys - sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - from flash_attn.flash_attn_triton_amd import flash_attn_3 as flash_attn_3_gpu + repo_root = Path(__file__).resolve().parent.parent + if str(repo_root) not in sys.path: + sys.path.insert(0, str(repo_root)) + from flash_attn.flash_attn_triton_amd import flash_attn_3 as flash_attn_3_gpu # type: ignore else: # isort: off # We need to import the CUDA kernels after importing torch From f26c5baa79cfe2312e5e60c727b89e9865d2d554 Mon Sep 17 00:00:00 2001 From: Michael Date: Mon, 12 Jan 2026 22:13:08 +0000 Subject: [PATCH 16/27] update the to machine new changes --- hopper/flash_attn_interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 36d928f30f3..92a014624e1 100755 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -101,7 +101,7 @@ def _flash_attn_forward( ] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] seqlens_rotary = maybe_contiguous(seqlens_rotary) - out, softmax_lse, *rest = flash_attn_3_gpu.fwd( + out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_gpu.fwd( q, k, v, @@ -279,7 +279,7 @@ def _flash_attn_backward( ) -> torch.Tensor: # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - dq, dk, dv, softmax_d, *rest = flash_attn_3_gpu.bwd( + softmax_d, *rest = flash_attn_3_gpu.bwd( dout, q, k, From c9bba2baada7b5484dacb4dc16efd4e725594474 Mon Sep 17 00:00:00 2001 From: Michael Date: Mon, 12 Jan 2026 17:16:23 -0500 Subject: [PATCH 17/27] save --- flash_attn/flash_attn_triton_amd/interface_v3.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py index 3ed35de5cd1..4f391aa7a1b 100755 --- a/flash_attn/flash_attn_triton_amd/interface_v3.py +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -51,7 +51,7 @@ def fwd( num_splits: int = 1, pack_gqa=None, sm_margin: int = 0, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Flash Attention v3 forward pass compatible interface for AMD Triton implementation. @@ -456,8 +456,9 @@ def fwd( ), f"[fwd_v3] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" # Return format compatible with v3 - # V3 returns (out, softmax_lse, *rest) where rest can be empty or contain additional outputs - return out, softmax_lse + # V3 returns (out, softmax_lse, out_accum, softmax_lse_accum) + # out_accum and softmax_lse_accum are None for Triton AMD (no split-k accumulation) + return out, softmax_lse, None, None def bwd( From 080c08c66a43cd41cf01698eee2d34f81d6dfc51 Mon Sep 17 00:00:00 2001 From: Michael Date: Mon, 12 Jan 2026 22:19:05 +0000 Subject: [PATCH 18/27] fix more bugs --- flash_attn/flash_attn_triton_amd/interface_v3.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py index 4f391aa7a1b..5e95d3c661d 100755 --- a/flash_attn/flash_attn_triton_amd/interface_v3.py +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -484,7 +484,7 @@ def bwd( softcap: float, deterministic: bool, sm_margin: int = 0, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor]: """ Flash Attention v3 backward pass compatible interface for AMD Triton implementation. @@ -677,9 +677,9 @@ def bwd( delta.shape == expected_delta_shape ), f"[bwd_v3] delta shape {delta.shape} != {expected_delta_shape}" - # V3 expects (dq, dk, dv, softmax_d, *rest) + # V3 expects (softmax_d, *rest) # delta is the softmax_d in this case - return dq, dk, dv, delta + return delta def fwd_combine( From 4b62f51a3f38645afa76702d280b00d30cd06d0d Mon Sep 17 00:00:00 2001 From: Michael Date: Tue, 13 Jan 2026 11:55:21 -0500 Subject: [PATCH 19/27] remove random seed --- flash_attn/flash_attn_triton_amd/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index d0a20eb6aa4..8c628002d5f 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -38,8 +38,6 @@ os.environ.get("DEBUG_TRITON_DETAIL", "0").lower() in ("1", "true", "yes") and USE_TRITON_INTERPRET ) -if USE_TRITON_ROCM: # TODO remove this - random.seed(42) BWD_MODE: Literal["fused", "fused_atomic", "split"] = "fused" USE_EXP2 = True PHILOX_SEED = 0x1BF58 From 98b2e379d851207f12e9689aa3287394d58f3943 Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 14 Jan 2026 08:44:56 -0500 Subject: [PATCH 20/27] clean up --- .gitignore | 25 ++----------------------- README.md | 4 ++-- tests/test_flash_attn_triton_amd.py | 2 +- 3 files changed, 5 insertions(+), 26 deletions(-) diff --git a/.gitignore b/.gitignore index c4a123e4e0b..dc508654045 100644 --- a/.gitignore +++ b/.gitignore @@ -32,26 +32,5 @@ var/ # Dev venv -# AMD -scripts -csrc/flash_attn_ck -.eggs -log -*.rocprof* -*.log -core.* -gpucore.* -*.csv -*.png -*.html -*.json -*.txt -*.pth -*.md -*.crt -training/logs -training/data -# ck modules -csrc/composable_kernel -csrc/cutlass -.amd \ No newline at end of file +# compile-time generated file +flash_attn_config.py \ No newline at end of file diff --git a/README.md b/README.md index 85c05a9ab68..96ca5ba521d 100755 --- a/README.md +++ b/README.md @@ -154,7 +154,7 @@ To get started with the triton backend for AMD, follow the steps below. First install the torch for ROCm from https://pytorch.org/get-started/locally/ if it is not installed. The torch and triton will be installed. ``` -pip install triton==3.3.0 +pip install triton==3.4.0 ``` Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. @@ -181,7 +181,7 @@ FROM rocm/pytorch:latest WORKDIR /workspace # install triton -RUN pip install triton==3.3.0 +RUN pip install triton==3.4.0 # build flash attention with triton backend RUN git clone https://github.com/Dao-AILab/flash-attention &&\ diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py index ba1932438e2..7f46b6d2813 100755 --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -16,7 +16,7 @@ from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import _get_block_size_n from flash_attn.layers.rotary import apply_rotary_emb -from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_hip, is_rdna, generate_bshd_tensor, save_tensor_to_csv +from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_hip, is_rdna MAX_HEADDIM_SM8x = 192 From cbfde426aaead3717dc6297a6d826d9e993e36c5 Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 14 Jan 2026 11:58:03 -0500 Subject: [PATCH 21/27] update readme --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 96ca5ba521d..ae88267e0ca 100755 --- a/README.md +++ b/README.md @@ -141,12 +141,12 @@ These features are supported in Fwd and Bwd 6) Dropout 7) Rotary embeddings 8) ALiBi +9) Paged Attention +10) FP8 We are working on the following things -1) Paged Attention -2) Sliding Window -3) FP8 -4) Performance Improvements +1) Sliding Window +2) Performance Improvements ##### Getting Started To get started with the triton backend for AMD, follow the steps below. @@ -154,7 +154,7 @@ To get started with the triton backend for AMD, follow the steps below. First install the torch for ROCm from https://pytorch.org/get-started/locally/ if it is not installed. The torch and triton will be installed. ``` -pip install triton==3.4.0 +pip install triton==3.5.1 ``` Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. @@ -181,7 +181,7 @@ FROM rocm/pytorch:latest WORKDIR /workspace # install triton -RUN pip install triton==3.4.0 +RUN pip install triton==3.5.1 # build flash attention with triton backend RUN git clone https://github.com/Dao-AILab/flash-attention &&\ From 218c2eec21a268bf21a3891f70e7f5532bdc142b Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 14 Jan 2026 12:16:48 -0500 Subject: [PATCH 22/27] print tensor stats for debug --- .../flash_attn_triton_amd/interface_v2.py | 45 ++-- .../flash_attn_triton_amd/interface_v3.py | 213 ++++-------------- flash_attn/flash_attn_triton_amd/utils.py | 13 ++ 3 files changed, 76 insertions(+), 195 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/interface_v2.py b/flash_attn/flash_attn_triton_amd/interface_v2.py index d303ba63e7a..02f52d57930 100644 --- a/flash_attn/flash_attn_triton_amd/interface_v2.py +++ b/flash_attn/flash_attn_triton_amd/interface_v2.py @@ -12,6 +12,7 @@ PHILOX_OFFSET, SHAPE_EXPECTATIONS, round_multiple, + tensor_stats, ) @@ -46,11 +47,11 @@ def fwd( if DEBUG: print() print("flash_attn_triton_amd.py::fwd inputs") - print("q:", q.shape) - print("k:", k.shape) - print("v:", v.shape) - print("out:", out.shape if out is not None else None) - print("alibi_slopes:", alibi_slopes) + print(tensor_stats("q", q)) + print(tensor_stats("k", k)) + print(tensor_stats("v", v)) + print(tensor_stats("out", out)) + print(tensor_stats("alibi_slopes", alibi_slopes)) print("dropout_p:", dropout_p) print("softmax_scale:", softmax_scale) print("causal:", causal) @@ -161,9 +162,9 @@ def fwd( if DEBUG: print("flash_attn_triton_amd.py::fwd outputs") - print("o:", out.shape if out is not None else None) - print("softmax_lse:", softmax_lse.shape if softmax_lse is not None else None) - print("sd_mask:", sd_mask.shape if sd_mask is not None else None) + print(tensor_stats("out", out)) + print(tensor_stats("softmax_lse", softmax_lse)) + print(tensor_stats("sd_mask", sd_mask)) print("rng_state:", rng_state) # --- Assertions (shape + dtype contracts) --- @@ -235,24 +236,22 @@ def bwd( if DEBUG: print() print("flash_attn_triton_amd.py::bwd inputs") - print("dout:", dout, dout.shape) - print("q:", q.shape) - print("k:", k.shape) - print("v:", v.shape) - print("out:", out.shape) - print("softmax_lse:", softmax_lse.shape) - print("dq:", dq.shape if dq is not None else None) - print("dk:", dk.shape if dk is not None else None) - print("dv:", dv.shape if dv is not None else None) - print("alibi_slopes:", alibi_slopes) + print(tensor_stats("dout", dout)) + print(tensor_stats("q", q)) + print(tensor_stats("k", k)) + print(tensor_stats("v", v)) + print(tensor_stats("out", out)) + print(tensor_stats("softmax_lse", softmax_lse)) + print(tensor_stats("dq", dq)) + print(tensor_stats("dk", dk)) + print(tensor_stats("dv", dv)) + print(tensor_stats("alibi_slopes", alibi_slopes)) print("dropout_p:", dropout_p) - print("out:", out) print("softmax_scale:", softmax_scale) print("causal:", causal) print("window_size_left:", window_size_left) print("window_size_right:", window_size_right) print("deterministic:", deterministic) - print("gen_:", gen_) print("rng_state:", rng_state) dq = torch.zeros_like(q) if dq is None else dq.zero_() @@ -322,9 +321,9 @@ def bwd( if DEBUG: print("flash_attn_triton_amd.py::bwd outputs") - print("dv:", dv, dv.shape) - print("dk:", dk, dk.shape) - print("dq:", dq, dq.shape) + print(tensor_stats("dq", dq)) + print(tensor_stats("dk", dk)) + print(tensor_stats("dv", dv)) # --- Assertions --- assert dq.shape == q.shape, f"[bwd] dq shape {dq.shape} != q shape {q.shape}" assert dk.shape == k.shape, f"[bwd] dk shape {dk.shape} != k shape {k.shape}" diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py index 5e95d3c661d..1ada416ed61 100755 --- a/flash_attn/flash_attn_triton_amd/interface_v3.py +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -13,6 +13,7 @@ PHILOX_OFFSET, is_fp8, get_recommended_fp8_dtype, + tensor_stats, ) @@ -61,95 +62,29 @@ def fwd( if DEBUG: print() print("interface_fa_v3.py::fwd inputs") - print("q:", q.dtype if q is not None else None, q.shape) - print("k:", k.dtype if k is not None else None, k.shape) - print("v:", v.dtype if v is not None else None, v.shape) - print( - "k_new:", - k_new.dtype if k_new is not None else None, - k_new.shape if k_new is not None else None, - ) - print( - "v_new:", - v_new.dtype if v_new is not None else None, - v_new.shape if v_new is not None else None, - ) - print( - "qv:", - qv.dtype if qv is not None else None, - qv.shape if qv is not None else None, - ) - print( - "out:", - out.dtype if out is not None else None, - out.shape if out is not None else None, - ) - print( - "cu_seqlens_q:", - cu_seqlens_q, - cu_seqlens_q.shape if cu_seqlens_q is not None else None, - ) - print( - "cu_seqlens_k:", - cu_seqlens_k, - cu_seqlens_k.shape if cu_seqlens_k is not None else None, - ) - print( - "cu_seqlens_k_new:", - cu_seqlens_k_new, - cu_seqlens_k_new.shape if cu_seqlens_k_new is not None else None, - ) - print( - "seqused_q:", seqused_q, seqused_q.shape if seqused_q is not None else None - ) - print( - "seqused_k:", seqused_k, seqused_k.shape if seqused_k is not None else None - ) + print(tensor_stats("q", q)) + print(tensor_stats("k", k)) + print(tensor_stats("v", v)) + print(tensor_stats("k_new", k_new)) + print(tensor_stats("v_new", v_new)) + print(tensor_stats("qv", qv)) + print(tensor_stats("out", out)) + print(tensor_stats("cu_seqlens_q", cu_seqlens_q)) + print(tensor_stats("cu_seqlens_k", cu_seqlens_k)) + print(tensor_stats("cu_seqlens_k_new", cu_seqlens_k_new)) + print(tensor_stats("seqused_q", seqused_q)) + print(tensor_stats("seqused_k", seqused_k)) print("max_seqlen_q:", max_seqlen_q) print("max_seqlen_k:", max_seqlen_k) - print( - "page_table:", - page_table, - page_table.shape if page_table is not None else None, - ) - print( - "kv_batch_idx:", - kv_batch_idx, - kv_batch_idx.shape if kv_batch_idx is not None else None, - ) - print( - "leftpad_k:", leftpad_k, leftpad_k.shape if leftpad_k is not None else None - ) - print( - "rotary_cos:", - rotary_cos, - rotary_cos.shape if rotary_cos is not None else None, - ) - print( - "rotary_sin:", - rotary_sin, - rotary_sin.shape if rotary_sin is not None else None, - ) - print( - "seqlens_rotary:", - seqlens_rotary, - seqlens_rotary.shape if seqlens_rotary is not None else None, - ) - print( - "q_descale:", - q_descale.dtype if q_descale is not None else None, - q_descale.shape if q_descale is not None else None, - ) - print( - "k_descale:", - k_descale.dtype if k_descale is not None else None, - k_descale.shape if k_descale is not None else None, - ) - print( - "v_descale:", - v_descale.dtype if v_descale is not None else None, - v_descale.shape if v_descale is not None else None, - ) + print(tensor_stats("page_table", page_table)) + print(tensor_stats("kv_batch_idx", kv_batch_idx)) + print(tensor_stats("leftpad_k", leftpad_k)) + print(tensor_stats("rotary_cos", rotary_cos)) + print(tensor_stats("rotary_sin", rotary_sin)) + print(tensor_stats("seqlens_rotary", seqlens_rotary)) + print(tensor_stats("q_descale", q_descale)) + print(tensor_stats("k_descale", k_descale)) + print(tensor_stats("v_descale", v_descale)) print("softmax_scale:", softmax_scale) print("causal:", causal) print("window_size_left:", window_size_left) @@ -401,16 +336,8 @@ def fwd( if DEBUG: print("interface_fa_v3.py::fwd outputs") - print( - "out:", - out.dtype if out is not None else None, - out.shape if out is not None else None, - ) - print( - "softmax_lse:", - softmax_lse.dtype if softmax_lse is not None else None, - softmax_lse.shape if softmax_lse is not None else None, - ) + print(tensor_stats("out", out)) + print(tensor_stats("softmax_lse", softmax_lse)) # --- Assertions (FA3 always expects exact shapes) --- # out: same shape as q except last dim is v's head_dim @@ -494,61 +421,19 @@ def bwd( if DEBUG: print() print("interface_fa_v3.py::bwd inputs") - print( - "dout:", - dout.dtype if dout is not None else None, - dout.shape if dout is not None else None, - ) - print( - "q:", q.dtype if q is not None else None, q.shape if q is not None else None - ) - print( - "k:", k.dtype if k is not None else None, k.shape if k is not None else None - ) - print( - "v:", v.dtype if v is not None else None, v.shape if v is not None else None - ) - print( - "out:", - out.dtype if out is not None else None, - out.shape if out is not None else None, - ) - print( - "softmax_lse:", - softmax_lse.dtype if softmax_lse is not None else None, - softmax_lse.shape if softmax_lse is not None else None, - ) - print( - "dq:", - dq.dtype if dq is not None else None, - dq.shape if dq is not None else None, - ) - print( - "dk:", - dk.dtype if dk is not None else None, - dk.shape if dk is not None else None, - ) - print( - "dv:", - dv.dtype if dv is not None else None, - dv.shape if dv is not None else None, - ) - print( - "cu_seqlens_q:", - cu_seqlens_q, - cu_seqlens_q.shape if cu_seqlens_q is not None else None, - ) - print( - "cu_seqlens_k:", - cu_seqlens_k, - cu_seqlens_k.shape if cu_seqlens_k is not None else None, - ) - print( - "seqused_q:", seqused_q, seqused_q.shape if seqused_q is not None else None - ) - print( - "seqused_k:", seqused_k, seqused_k.shape if seqused_k is not None else None - ) + print(tensor_stats("dout", dout)) + print(tensor_stats("q", q)) + print(tensor_stats("k", k)) + print(tensor_stats("v", v)) + print(tensor_stats("out", out)) + print(tensor_stats("softmax_lse", softmax_lse)) + print(tensor_stats("dq", dq)) + print(tensor_stats("dk", dk)) + print(tensor_stats("dv", dv)) + print(tensor_stats("cu_seqlens_q", cu_seqlens_q)) + print(tensor_stats("cu_seqlens_k", cu_seqlens_k)) + print(tensor_stats("seqused_q", seqused_q)) + print(tensor_stats("seqused_k", seqused_k)) print("max_seqlen_q:", max_seqlen_q) print("max_seqlen_k:", max_seqlen_k) print("softmax_scale:", softmax_scale) @@ -637,26 +522,10 @@ def bwd( if DEBUG: print("interface_fa_v3.py::bwd outputs") - print( - "dq:", - dq.dtype if dq is not None else None, - dq.shape if dq is not None else None, - ) - print( - "dk:", - dk.dtype if dk is not None else None, - dk.shape if dk is not None else None, - ) - print( - "dv:", - dv.dtype if dv is not None else None, - dv.shape if dv is not None else None, - ) - print( - "delta:", - delta.dtype if delta is not None else None, - delta.shape if delta is not None else None, - ) + print(tensor_stats("dq", dq)) + print(tensor_stats("dk", dk)) + print(tensor_stats("dv", dv)) + print(tensor_stats("delta", delta)) # --- Assertions (FA3 always expects exact shapes) --- # Gradients should match input shapes diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 8c628002d5f..d4cfda18a23 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -45,6 +45,19 @@ SHAPE_EXPECTATIONS: Literal["exact", "rounded"] = "exact" +def tensor_stats(name: str, t: torch.Tensor) -> str: + """Return a string with tensor shape, dtype, and distribution stats for debugging.""" + if t is None: + return f"{name}: None" + flat = t.float().flatten() + return ( + f"{name}: shape={tuple(t.shape)}, dtype={t.dtype}, " + f"min={flat.min().item():.6g}, max={flat.max().item():.6g}, " + f"mean={flat.mean().item():.6g}, median={flat.median().item():.6g}, " + f"std={flat.std().item():.6g}" + ) + + # ------------------------------- # Input Helper # ------------------------------- From 5e3c7a79a37a7eddb5504aba93df0640236bddb1 Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 14 Jan 2026 12:42:19 -0500 Subject: [PATCH 23/27] disable sliding window tests --- README.md | 59 ++++--------------- .../flash_attn_triton_amd/interface_v2.py | 9 +++ .../flash_attn_triton_amd/interface_v3.py | 9 +++ tests/test_flash_attn_triton_amd.py | 56 ++++-------------- 4 files changed, 40 insertions(+), 93 deletions(-) diff --git a/README.md b/README.md index ae88267e0ca..406157fed60 100755 --- a/README.md +++ b/README.md @@ -128,54 +128,24 @@ FlashAttention-2 ROCm CK backend currently supports: 3. Both forward's and backward's head dimensions up to 256. #### Triton Backend -The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress. +The Triton implementation of [Flash Attention](https://tridao.me/publications/flash2/flash2.pdf) supports AMD's CDNA (MI200, MI300) and RDNA GPUs using fp16, bf16, and fp32 datatypes. It provides forward and backward passes with causal masking, variable sequence lengths, arbitrary Q/KV sequence lengths and head sizes, MQA/GQA, dropout, rotary embeddings, ALiBi, paged attention, and FP8 (via the Flash Attention v3 interface). Sliding window attention is currently a work in progress. -It supports AMD's CDNA (MI200, MI300) and RDNA GPU's using fp16, bf16 and fp32 datatypes. - -These features are supported in Fwd and Bwd -1) Fwd and Bwd with causal masking -2) Variable sequence lengths -3) Arbitrary Q and KV sequence lengths -4) Arbitrary head sizes -5) Multi and grouped query attention -6) Dropout -7) Rotary embeddings -8) ALiBi -9) Paged Attention -10) FP8 - -We are working on the following things -1) Sliding Window -2) Performance Improvements - -##### Getting Started -To get started with the triton backend for AMD, follow the steps below. - -First install the torch for ROCm from https://pytorch.org/get-started/locally/ if it is not installed. The torch and triton will be installed. - -``` +To install, first get PyTorch for ROCm from https://pytorch.org/get-started/locally/, then install Triton and Flash Attention: +```sh pip install triton==3.5.1 -``` -Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. - -``` cd flash-attention FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install ``` -To test that things are working, you can run our tests. These tests take hours so you don't need to run the full thing. -``` +To run the tests (note: full suite takes hours): +```sh FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py ``` -You can use autotune for better performance by using this flag `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"` -``` -FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE -``` +For better performance, enable autotune with `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"`. -###### Docker -You can also use the Dockerfile below which does the above steps on top of the latest rocm/pytorch image. -``` +For a quick start with Docker: +```dockerfile FROM rocm/pytorch:latest WORKDIR /workspace @@ -193,17 +163,12 @@ WORKDIR /workspace/flash-attention # set env variable to use triton backend ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" - ``` -To build the docker file -``` -docker build -t fa_triton . -``` - -To run the docker image -``` -docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton +Build and run: +```sh +docker build -t flash-attn-triton . +docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri flash-attn-triton ``` ## How to use FlashAttention diff --git a/flash_attn/flash_attn_triton_amd/interface_v2.py b/flash_attn/flash_attn_triton_amd/interface_v2.py index 02f52d57930..d4803a2c252 100644 --- a/flash_attn/flash_attn_triton_amd/interface_v2.py +++ b/flash_attn/flash_attn_triton_amd/interface_v2.py @@ -233,6 +233,15 @@ def bwd( "softcap is not supported in the AMD Triton FA2 interface (expected 0.0)." ) + # Check for sliding window - backward doesn't support it yet + is_sliding_window = (window_size_left >= 0) or (window_size_right >= 0) + if is_sliding_window: + raise NotImplementedError( + f"Sliding window attention is not yet supported in the AMD Triton backward pass " + f"(window_size_left={window_size_left}, window_size_right={window_size_right}). " + f"Use window_size=(-1, -1) for full attention." + ) + if DEBUG: print() print("flash_attn_triton_amd.py::bwd inputs") diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py index 1ada416ed61..077bb58319c 100755 --- a/flash_attn/flash_attn_triton_amd/interface_v3.py +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -446,6 +446,15 @@ def bwd( # Check for unsupported features in backward pass + # Handle sliding window - backward doesn't support it yet + is_sliding_window = (window_size_left >= 0) or (window_size_right >= 0) + if is_sliding_window: + raise NotImplementedError( + f"Sliding window attention is not yet supported in the AMD Triton backward pass " + f"(window_size_left={window_size_left}, window_size_right={window_size_right}). " + f"Use window_size=(-1, -1) for full attention." + ) + # Handle softcap if softcap != 0.0: raise NotImplementedError( diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py index 7f46b6d2813..b924890b115 100755 --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -573,7 +573,7 @@ def get_dropout_fraction( # @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("alibi", [False, True]) # @pytest.mark.parametrize("alibi", [False]) -@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False]) @@ -712,10 +712,6 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) - if local == True: - print("Sliding Window not supported in backward yet") - return - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() @@ -726,7 +722,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ # @pytest.mark.parametrize("deterministic", [True]) @pytest.mark.parametrize("alibi", [False, True]) # @pytest.mark.parametrize("alibi", [True]) -@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [False]) @@ -864,10 +860,6 @@ def test_flash_attn_varlen_qkvpacked( if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) - if local == True: - print("Sliding Window not supported in backward yet") - return - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() @@ -882,7 +874,7 @@ def test_flash_attn_varlen_qkvpacked( # @pytest.mark.parametrize("deterministic", [True]) @pytest.mark.parametrize("alibi", [False, True]) # @pytest.mark.parametrize("alibi", [False]) -@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) @@ -1141,10 +1133,6 @@ def test_flash_attn_output( if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) - if local == True: - print("Sliding Window not supported in backward yet") - return - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() @@ -1161,7 +1149,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize("deterministic", [True]) @pytest.mark.parametrize("alibi", [False, True]) # @pytest.mark.parametrize("alibi", [True]) -@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [True]) @@ -1467,10 +1455,6 @@ def test_flash_attn_varlen_output( if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04) - if local == True: - print("Sliding Window not supported in backward yet") - return - if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() @@ -1479,7 +1463,7 @@ def test_flash_attn_varlen_output( @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -1585,10 +1569,6 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 - if local == True: - print("Sliding Window not supported in backward yet") - return - assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 @@ -1596,7 +1576,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -1757,10 +1737,6 @@ def test_flash_attn_varlen_causal( # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 - if local == True: - print("Sliding Window not supported in backward yet") - return - if test_backward: assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 @@ -1773,7 +1749,7 @@ def test_flash_attn_varlen_causal( # @pytest.mark.parametrize("deterministic", [True]) @pytest.mark.parametrize("alibi", [False, True]) # @pytest.mark.parametrize("alibi", [True]) -@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) @@ -1891,10 +1867,6 @@ def test_flash_attn_splitkv( # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 - if local == True: - print("Sliding Window not supported in backward yet") - return - mult = 2 if not alibi else 8 assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4 assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4 @@ -1911,7 +1883,7 @@ def test_flash_attn_splitkv( # @pytest.mark.parametrize("new_kv", [False]) @pytest.mark.parametrize("alibi", [False, True]) # @pytest.mark.parametrize("alibi", [False]) -@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False]) @@ -2432,7 +2404,7 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) @@ -2480,10 +2452,6 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True) - if local == True: - print("Sliding Window not supported in backward yet") - return - g = torch.randn_like(out) dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True) for _ in range(50): @@ -2495,7 +2463,7 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) @@ -2572,10 +2540,6 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus deterministic=True, ) - if local == True: - print("Sliding Window not supported in backward yet") - return - g = torch.randn_like(out) dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) for _ in range(50): From 537c5079accbfde095702c58265461f73dd50e69 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Mon, 19 Jan 2026 14:03:49 -0600 Subject: [PATCH 24/27] add rdna configs --- flash_attn/flash_attn_triton_amd/bwd.py | 117 +++++--- .../flash_attn_triton_amd/fwd_decode.py | 249 +++++++++--------- .../flash_attn_triton_amd/fwd_prefill.py | 63 ++--- 3 files changed, 236 insertions(+), 193 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index d2ed7aa113a..8f40953c395 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -9,38 +9,38 @@ AUTOTUNE, compute_fp8_scaling_factors, get_cu_count, - is_cdna, is_fp8, get_arch, ) -def get_bwd_configs(autotune: bool): - # keys - preprocess_autotune_keys = [ - "max_seqlen_q", - "ACTUAL_HEAD_DIM", - "IS_VARLEN", - ] +PREPROCESS_AUTOTUNE_KEYS = [ + "max_seqlen_q", + "ACTUAL_HEAD_DIM", + "IS_VARLEN", +] + +CAUSAL_AUTOTUNE_KEYS = [ + "dropout_p", + "max_seqlen_q", + "max_seqlen_k", + "ACTUAL_HEAD_DIM", + "IS_VARLEN", + "HQ", + "HK", +] + +NONCAUSAL_AUTOTUNE_KEYS = [ + "dropout_p", + "max_seqlen_q", + "max_seqlen_k", + "ACTUAL_HEAD_DIM", + "IS_VARLEN", + "HQ", + "HK", +] - causal_autotune_keys = [ - "dropout_p", - "max_seqlen_q", - "max_seqlen_k", - "ACTUAL_HEAD_DIM", - "IS_VARLEN", - "HQ", - "HK", - ] - noncausal_autotune_keys = [ - "dropout_p", - "max_seqlen_q", - "max_seqlen_k", - "ACTUAL_HEAD_DIM", - "IS_VARLEN", - "HQ", - "HK", - ] +def get_bwd_configs(autotune: bool): # default config if not autotune: @@ -320,6 +320,47 @@ def get_bwd_configs(autotune: bool): num_warps=4, ), ] + elif arch in ( + "gfx1030", + "gfx1100", + "gfx1101", + "gfx1102", + "gfx1200", + "gfx1201", + ): # RDNA architectures + preprocess_autotune_configs = [ + triton.Config( + {"PRE_BLOCK": 128, "waves_per_eu": 1}, num_stages=1, num_warps=4 + ), + ] + noncausal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + ] + causal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + ] else: preprocess_autotune_configs = [ triton.Config( @@ -369,9 +410,9 @@ def get_bwd_configs(autotune: bool): ), f"BLOCK_N1 ({causal_cfg.all_kwargs()['BLOCK_N1']}) must equal BLOCK_M2 ({causal_cfg.all_kwargs()['BLOCK_M2']})" return ( - (preprocess_autotune_configs, preprocess_autotune_keys), - (causal_autotune_configs, causal_autotune_keys), - (noncausal_autotune_configs, noncausal_autotune_keys), + preprocess_autotune_configs, + causal_autotune_configs, + noncausal_autotune_configs, ) # param options @@ -485,16 +526,16 @@ def get_bwd_configs(autotune: bool): ) return ( - (preprocess_autotune_configs, preprocess_autotune_keys), - (causal_autotune_configs, causal_autotune_keys), - (noncausal_autotune_configs, noncausal_autotune_keys), + preprocess_autotune_configs, + causal_autotune_configs, + noncausal_autotune_configs, ) # os.environ["TRITON_PRINT_AUTOTUNING"] = "1" ( - (preprocess_autotune_configs, preprocess_autotune_keys), - (causal_autotune_configs, causal_autotune_keys), - (noncausal_autotune_configs, noncausal_autotune_keys), + preprocess_autotune_configs, + causal_autotune_configs, + noncausal_autotune_configs, ) = get_bwd_configs(AUTOTUNE) @@ -2476,7 +2517,7 @@ def _bwd_kernel_split_dq_noncausal( # Delta: (batch, nheads_q, max_seqlens_q) @triton.autotune( configs=preprocess_autotune_configs, - key=preprocess_autotune_keys, + key=PREPROCESS_AUTOTUNE_KEYS, use_cuda_graph=True, ) @triton.jit @@ -2888,7 +2929,7 @@ def _bwd_dq_inner( @triton.autotune( configs=causal_autotune_configs, - key=causal_autotune_keys, + key=CAUSAL_AUTOTUNE_KEYS, use_cuda_graph=True, ) @triton.jit @@ -3467,7 +3508,7 @@ def bwd_kernel_fused_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_ @triton.autotune( configs=noncausal_autotune_configs, - key=noncausal_autotune_keys, + key=NONCAUSAL_AUTOTUNE_KEYS, use_cuda_graph=True, ) @triton.jit diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index 4645dcc97fe..9b8b5aea3f5 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -11,107 +11,111 @@ get_padded_headsize, get_shape_and_strides_from_layout, apply_rotary, - is_cdna, is_fp8, get_recommended_fp8_dtype, ) -def get_cdna_autotune_configs(): - return [ - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - # Fall-back config. - triton.Config( - {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - ], [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_BLOCK_DMODEL", - "VARLEN", - "HQ", - "HK", - ] - - -def get_autotune_configs(): - if AUTOTUNE: - if is_cdna(): - autotune_configs, autotune_keys = get_cdna_autotune_configs() - fwd_auto_tune_configs, fwd_autotune_keys = autotune_configs, autotune_keys - reduce_auto_tune_configs, reduce_autotune_keys = ( - autotune_configs, - autotune_keys, - ) - return (fwd_auto_tune_configs, fwd_autotune_keys), ( - reduce_auto_tune_configs, - reduce_autotune_keys, - ) +FWD_DECODE_AUTOTUNE_KEYS = [ + "N_CTX_Q", + "N_CTX_K", + "ACTUAL_BLOCK_DMODEL", + "H_q", + "H_kv", + "IS_CAUSAL", + "IS_GQA", +] + +# Maximum BLOCK_M across all configs (for intermediate tensor allocation) +MAX_BLOCK_M = 64 + + +def get_fwd_decode_configs(autotune: bool): + """ + Returns configs for both the splitK kernel and reduce kernel. + + Returns: + (splitk_configs, reduce_config): Tuple of config lists for each kernel + """ + + if not autotune: + arch = get_arch() + + if arch in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"): + # RDNA architectures + splitk_configs = [ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 2}, + num_stages=1, + num_warps=2, + ), + ] + reduce_configs = [triton.Config({}, num_stages=1, num_warps=2)] + elif arch in ("gfx940", "gfx941", "gfx942", "gfx950"): + # CDNA architectures + splitk_configs = [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1}, + num_stages=1, + num_warps=4, + ), + ] + reduce_configs = [triton.Config({}, num_stages=1, num_warps=4)] else: - raise ValueError("Unknown Device Type") - else: - autotune_configs, autotune_keys = [ - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - ], [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_BLOCK_DMODEL", - "VARLEN", - "HQ", - "HK", - ] + # Default / fallback + splitk_configs = [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1}, + num_stages=1, + num_warps=4, + ), + ] + reduce_configs = [triton.Config({}, num_stages=1, num_warps=4)] + + return splitk_configs, reduce_configs + + # ===================== Autotune Sweep ===================== + arch = get_arch() + splitk_configs = [] + + BLOCK_M_OPTIONS = [64, 32, 16] + BLOCK_N_OPTIONS = [128, 64, 32, 16] + NUM_WARPS_OPTIONS = [2, 4] + NUM_STAGES_OPTIONS = [1] + WAVES_PER_EU_OPTIONS = [4, 2, 1] + + # Ensure BLOCK_M options don't exceed MAX_BLOCK_M + assert all(bm <= MAX_BLOCK_M for bm in BLOCK_M_OPTIONS), \ + f"BLOCK_M_OPTIONS {BLOCK_M_OPTIONS} exceeds MAX_BLOCK_M {MAX_BLOCK_M}" + + for bm in BLOCK_M_OPTIONS: + for bn in BLOCK_N_OPTIONS: + for waves in WAVES_PER_EU_OPTIONS: + for nw in NUM_WARPS_OPTIONS: + for ns in NUM_STAGES_OPTIONS: + splitk_configs.append( + triton.Config( + { + "BLOCK_M": bm, + "BLOCK_N": bn, + "waves_per_eu": waves, + }, + num_stages=ns, + num_warps=nw, + ) + ) + + # Reduce kernel configs - sweep num_warps + NUM_WARPS_REDUCE_OPTIONS = [2, 4] + reduce_configs = [ + triton.Config({}, num_stages=1, num_warps=nw) + for nw in NUM_WARPS_REDUCE_OPTIONS + ] - fwd_auto_tune_configs, fwd_autotune_keys = autotune_configs, autotune_keys - reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys - return (fwd_auto_tune_configs, fwd_autotune_keys), ( - reduce_auto_tune_configs, - reduce_autotune_keys, - ) + return splitk_configs, reduce_configs -(fwd_auto_tune_configs, fwd_autotune_keys), ( - reduce_auto_tune_configs, - reduce_autotune_keys, -) = get_autotune_configs() +fwd_decode_splitk_configs, fwd_decode_reduce_configs = get_fwd_decode_configs(AUTOTUNE) @triton.jit @@ -128,18 +132,18 @@ def _attn_fwd_inner( q_descale, k_descale, v_descale, # FP8 scaling factors + alibi_slope, + apply_col_mask, IS_FP8: tl.constexpr, # FP8 flag BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, N_CTX_Q: tl.constexpr, N_CTX_K_FINAL: tl.constexpr, USE_ALIBI: tl.constexpr, - alibi_slope, USE_SLIDING_WINDOW: tl.constexpr, IS_CAUSAL: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, - APPLY_COL_MASK: tl.constexpr, # apply provided col_mask when True ): # -- compute qk --- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) @@ -204,7 +208,7 @@ def _attn_fwd_inner( # Column mask (tail / variable-length). Instead of recomputing an arange each time, # we accept a precomputed mask from the caller (col_valid_mask). - if APPLY_COL_MASK: + if apply_col_mask: # Expect col_mask shape: [BLOCK_N]. True where column is within sequence. qk = tl.where(col_mask[None, :], qk, float("-inf")) @@ -235,11 +239,11 @@ def _attn_fwd_inner( return m_i, l_i, acc -# @triton.autotune( -# configs=fwd_auto_tune_configs, -# key=fwd_autotune_keys, -# use_cuda_graph=True, -# ) +@triton.autotune( + configs=fwd_decode_splitk_configs, + key=FWD_DECODE_AUTOTUNE_KEYS, + use_cuda_graph=True, +) @triton.jit def _fwd_kernel_splitK( Q, @@ -313,7 +317,6 @@ def _fwd_kernel_splitK( BLOCK_DMODEL: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - BOUNDS_CHECKS_N: tl.constexpr, USE_CACHE_SEQLENs: tl.constexpr, USE_CACHE_BATCH_IDX: tl.constexpr, NEW_KV: tl.constexpr, @@ -540,22 +543,24 @@ def _fwd_kernel_splitK( q_descale, k_descale, v_descale, + alibi_slope, + True, IS_FP8, BLOCK_M, BLOCK_N, N_CTX_Q, N_CTX_K_FINAL, USE_ALIBI, - alibi_slope, USE_SLIDING_WINDOW, IS_CAUSAL, WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, - True, ) else: # Non-paged attention: process KV from cache # Note: Cache should be updated externally before calling this kernel + # Compute bounds check flag once: needed if split size not aligned to BLOCK_N or variable seqlens + bounds_checks_n = ((BLOCK_N_PER_SPLIT % BLOCK_N) > 0) | USE_CACHE_SEQLENs # loop over k, v and update accumulator for start_n in range(lo, hi, BLOCK_N): kT_ptrs = ( @@ -591,18 +596,18 @@ def _fwd_kernel_splitK( q_descale, k_descale, v_descale, + alibi_slope, + bounds_checks_n, IS_FP8, BLOCK_M, BLOCK_N, N_CTX_Q, N_CTX_K_FINAL, USE_ALIBI, - alibi_slope, USE_SLIDING_WINDOW, IS_CAUSAL, WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, - BOUNDS_CHECKS_N, ) # write back O @@ -623,11 +628,17 @@ def _fwd_kernel_splitK( tl.store(metadata_ptr + stride_m2, l_i) -# @triton.autotune( -# configs=reduce_auto_tune_configs, -# key=reduce_autotune_keys, -# use_cuda_graph=True, -# ) +FWD_DECODE_REDUCE_AUTOTUNE_KEYS = [ + "BLOCK_DMODEL", + "split_k", +] + + +@triton.autotune( + configs=fwd_decode_reduce_configs, + key=FWD_DECODE_REDUCE_AUTOTUNE_KEYS, + use_cuda_graph=True, +) @triton.jit def _splitK_reduce( Out_splitK, # [B*H*G, split_k, Mq, K] @@ -947,13 +958,6 @@ def attention_forward_decode_triton_impl( if cache_seqlens is not None: cache_seqlens[b] = start_idx + seqlen_new - # triton configs - BLOCK_M = 16 - BLOCK_N = 64 - num_stages = 1 - num_warps_fwd = 1 - num_warps_reduce = 4 - # kernel_configs is_new_kv = False # Cache has been updated, so no new KV in kernel use_alibi, (stride_az, stride_ah) = True if alibi_slopes is not None else False, ( @@ -1086,8 +1090,9 @@ def attention_forward_decode_triton_impl( split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_kc) split_size = (seqlen_kc + split_k - 1) // split_k - # setup grid - seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M + # setup grid - use lambda to get BLOCK_M from autotune + # Use MAX_BLOCK_M for intermediate tensor allocation to ensure enough space + seqlen_q_ceil = (seqlen_q + MAX_BLOCK_M - 1) // MAX_BLOCK_M * MAX_BLOCK_M grid = lambda META: ( triton.cdiv(seqlen_q, META["BLOCK_M"]), batch_size * n_group_q * heads_per_group_q, @@ -1321,11 +1326,8 @@ def attention_forward_decode_triton_impl( N_CTX_NEW=0, # No new KV, cache already updated BLOCK_N_PER_SPLIT=split_size, BLOCK_SIZE_K=block_size_k if use_block_table else 256, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, BLOCK_DMODEL=dim_padded, ACTUAL_BLOCK_DMODEL=dim_kc, - BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens, USE_CACHE_SEQLENs=use_cache_seqlens, USE_CACHE_BATCH_IDX=cache_batch_idx is not None, NEW_KV=False, # Cache already updated @@ -1339,8 +1341,6 @@ def attention_forward_decode_triton_impl( WINDOW_SIZE_RIGHT=window_size_right, USE_BLOCK_TABLE=use_block_table, IS_FP8=IS_FP8, - num_warps=num_warps_fwd, - num_stages=num_stages, ) if DEBUG: @@ -1400,5 +1400,4 @@ def attention_forward_decode_triton_impl( splitK_pow2=splitK_pow2, MASK_SPLITK=mask_split_k, PADDED_HEAD=is_padded_head, - num_warps=num_warps_reduce, ) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 81a1de19f20..a4292602dad 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -11,7 +11,6 @@ compute_fp8_scaling_factors, get_arch, get_cu_count, - is_cdna, is_fp8, is_rdna, apply_rotary, @@ -20,19 +19,21 @@ -def get_fwd_configs(autotune: bool): +FWD_PREFILL_AUTOTUNE_KEYS = [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL_QK", + "ACTUAL_BLOCK_DMODEL_V", + "IS_VARLEN", + "HQ", + "HK", +] + + +def get_fwd_prefill_configs(autotune: bool): configs = [] - keys = [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_BLOCK_DMODEL_QK", - "ACTUAL_BLOCK_DMODEL_V", - "IS_VARLEN", - "HQ", - "HK", - ] # get best config for the architecture if not autotune: @@ -99,17 +100,19 @@ def get_fwd_configs(autotune: bool): "gfx1200", "gfx1201", ): # RDNA architectures - configs.append( - triton.Config( - { - "BLOCK_M": 32, - "BLOCK_N": 32, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=2, - ) + configs.extend( + [ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=2, + ), + ] ) else: configs.append( @@ -125,11 +128,11 @@ def get_fwd_configs(autotune: bool): ) ) - return configs, keys + return configs # ===================== Autotune Sweep ===================== - BLOCK_M_OPTIONS = [128, 64, 32] - BLOCK_N_OPTIONS = [128, 64, 32] + BLOCK_M_OPTIONS = [128, 64, 32, 16] + BLOCK_N_OPTIONS = [128, 64, 32, 16] NUM_WARPS_OPTIONS = [2, 4, 8] NUM_STAGES_OPTIONS = [1, 2] WAVES_PER_EU_OPTIONS = [4, 2, 1] @@ -153,10 +156,10 @@ def get_fwd_configs(autotune: bool): ) ) - return configs, keys + return configs -fwd_prefill_autotune_configs, fwd_prefill_autotune_keys = get_fwd_configs(AUTOTUNE) +fwd_prefill_autotune_configs = get_fwd_prefill_configs(AUTOTUNE) @triton.jit @@ -986,7 +989,7 @@ def compute_block_masking( @triton.autotune( configs=fwd_prefill_autotune_configs, - key=fwd_prefill_autotune_keys, + key=FWD_PREFILL_AUTOTUNE_KEYS, use_cuda_graph=True, ) @triton.jit From d625bbd5a4a715ba62826f3745b647cd60c6a020 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Wed, 21 Jan 2026 10:00:47 -0600 Subject: [PATCH 25/27] fix k partial bug --- flash_attn/flash_attn_triton_amd/bwd.py | 54 ++++----- .../flash_attn_triton_amd/fwd_decode.py | 56 +++++----- .../flash_attn_triton_amd/fwd_prefill.py | 104 +++++++----------- 3 files changed, 96 insertions(+), 118 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 8f40953c395..55914b899ae 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -45,10 +45,11 @@ def get_bwd_configs(autotune: bool): # default config if not autotune: arch = get_arch() + # configs for the kernels if arch == "gfx942": if get_cu_count() < 304: - preprocess_autotune_configs = [ + preprocess_configs = [ triton.Config( {"PRE_BLOCK": 64, "waves_per_eu": 1}, num_stages=1, num_warps=8 ), @@ -59,7 +60,7 @@ def get_bwd_configs(autotune: bool): {"PRE_BLOCK": 128, "waves_per_eu": 2}, num_stages=1, num_warps=4 ), ] - noncausal_autotune_configs = [ + noncausal_configs = [ triton.Config( { "BLOCK_M1": 32, @@ -113,7 +114,7 @@ def get_bwd_configs(autotune: bool): num_warps=8, ), ] - causal_autotune_configs = [ + causal_configs = [ triton.Config( { "BLOCK_M1": 32, @@ -155,7 +156,7 @@ def get_bwd_configs(autotune: bool): ), ] else: - preprocess_autotune_configs = [ + preprocess_configs = [ triton.Config( {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 ), @@ -163,7 +164,7 @@ def get_bwd_configs(autotune: bool): {"PRE_BLOCK": 64, "waves_per_eu": 1}, num_stages=1, num_warps=4 ), ] - noncausal_autotune_configs = [ + noncausal_configs = [ triton.Config( { "BLOCK_M1": 32, @@ -204,7 +205,7 @@ def get_bwd_configs(autotune: bool): num_warps=4, ), ] - causal_autotune_configs = [ + causal_configs = [ triton.Config( { "BLOCK_M1": 32, @@ -233,7 +234,7 @@ def get_bwd_configs(autotune: bool): ), ] elif arch == "gfx950": - preprocess_autotune_configs = [ + preprocess_configs = [ triton.Config( {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 ), @@ -244,7 +245,7 @@ def get_bwd_configs(autotune: bool): {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=4 ), ] - noncausal_autotune_configs = [ + noncausal_configs = [ triton.Config( { "BLOCK_M1": 64, @@ -294,7 +295,7 @@ def get_bwd_configs(autotune: bool): num_warps=4, ), ] - causal_autotune_configs = [ + causal_configs = [ triton.Config( { "BLOCK_M1": 32, @@ -328,46 +329,44 @@ def get_bwd_configs(autotune: bool): "gfx1200", "gfx1201", ): # RDNA architectures - preprocess_autotune_configs = [ + preprocess_configs = [ triton.Config( - {"PRE_BLOCK": 128, "waves_per_eu": 1}, num_stages=1, num_warps=4 + {"PRE_BLOCK": 32}, num_stages=1, num_warps=4 ), ] - noncausal_autotune_configs = [ + noncausal_configs = [ triton.Config( { "BLOCK_M1": 32, - "BLOCK_N1": 128, - "BLOCK_M2": 128, + "BLOCK_N1": 32, + "BLOCK_M2": 32, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, - "waves_per_eu": 1, }, num_stages=1, num_warps=4, ), ] - causal_autotune_configs = [ + causal_configs = [ triton.Config( { "BLOCK_M1": 32, - "BLOCK_N1": 128, - "BLOCK_M2": 128, + "BLOCK_N1": 32, + "BLOCK_M2": 32, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, - "waves_per_eu": 1, }, num_stages=1, num_warps=4, ), ] else: - preprocess_autotune_configs = [ + preprocess_configs = [ triton.Config( {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 ), ] - noncausal_autotune_configs = [ + noncausal_configs = [ triton.Config( { "BLOCK_M1": 32, @@ -381,7 +380,7 @@ def get_bwd_configs(autotune: bool): num_warps=4, ), ] - causal_autotune_configs = [ + causal_configs = [ triton.Config( { "BLOCK_M1": 32, @@ -397,9 +396,7 @@ def get_bwd_configs(autotune: bool): ] # assert constraints - for noncausal_cfg, causal_cfg in zip( - noncausal_autotune_configs, causal_autotune_configs - ): + for noncausal_cfg, causal_cfg in zip(noncausal_configs, causal_configs): assert ( noncausal_cfg.all_kwargs()["BLOCK_N1"] == noncausal_cfg.all_kwargs()["BLOCK_M2"] @@ -409,12 +406,9 @@ def get_bwd_configs(autotune: bool): == causal_cfg.all_kwargs()["BLOCK_M2"] ), f"BLOCK_N1 ({causal_cfg.all_kwargs()['BLOCK_N1']}) must equal BLOCK_M2 ({causal_cfg.all_kwargs()['BLOCK_M2']})" - return ( - preprocess_autotune_configs, - causal_autotune_configs, - noncausal_autotune_configs, - ) + return (preprocess_configs, causal_configs, noncausal_configs) + # ===================== Autotune Sweep ===================== # param options PRE_BLOCK_OPTIONS = [64, 128] # og: 128 PRE_WAVES_PER_EU_OPTIONS = [1, 2] diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index 9b8b5aea3f5..e901da29ad6 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -43,36 +43,40 @@ def get_fwd_decode_configs(autotune: bool): if arch in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"): # RDNA architectures - splitk_configs = [ - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 2}, - num_stages=1, - num_warps=2, - ), - ] - reduce_configs = [triton.Config({}, num_stages=1, num_warps=2)] + return ( + [ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32}, + num_stages=1, + num_warps=4, + ), + ], + [triton.Config({}, num_stages=1, num_warps=4)], + ) elif arch in ("gfx940", "gfx941", "gfx942", "gfx950"): # CDNA architectures - splitk_configs = [ - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1}, - num_stages=1, - num_warps=4, - ), - ] - reduce_configs = [triton.Config({}, num_stages=1, num_warps=4)] + return ( + [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1}, + num_stages=1, + num_warps=4, + ), + ], + [triton.Config({}, num_stages=1, num_warps=4)], + ) else: # Default / fallback - splitk_configs = [ - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1}, - num_stages=1, - num_warps=4, - ), - ] - reduce_configs = [triton.Config({}, num_stages=1, num_warps=4)] - - return splitk_configs, reduce_configs + return ( + [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1}, + num_stages=1, + num_warps=4, + ), + ], + [triton.Config({}, num_stages=1, num_warps=4)], + ) # ===================== Autotune Sweep ===================== arch = get_arch() diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index a4292602dad..1729b5e5d68 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -33,13 +33,11 @@ def get_fwd_prefill_configs(autotune: bool): - configs = [] - # get best config for the architecture if not autotune: arch = get_arch() if arch == "gfx950": - configs.append( + return [ triton.Config( { "BLOCK_M": 128, @@ -50,37 +48,35 @@ def get_fwd_prefill_configs(autotune: bool): num_stages=1, num_warps=4, ) - ) + ] elif arch == "gfx942": if get_cu_count() < 304: - configs.extend( - [ - # best fp8 config - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - # best f16 config - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 32, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=2, - num_warps=4, - ), - ] - ) + return [ + # best fp8 config + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + # best f16 config + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 32, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=2, + num_warps=4, + ), + ] else: - configs.append( + return [ triton.Config( { "BLOCK_M": 128, @@ -91,7 +87,7 @@ def get_fwd_prefill_configs(autotune: bool): num_stages=1, num_warps=4, ) - ) + ] elif arch in ( "gfx1030", "gfx1100", @@ -100,22 +96,15 @@ def get_fwd_prefill_configs(autotune: bool): "gfx1200", "gfx1201", ): # RDNA architectures - configs.extend( - [ - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - triton.Config( - {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - ] - ) + return [ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + ] else: - configs.append( + return [ triton.Config( { "BLOCK_M": 64, @@ -126,11 +115,10 @@ def get_fwd_prefill_configs(autotune: bool): num_stages=1, num_warps=4, ) - ) - - return configs + ] # ===================== Autotune Sweep ===================== + configs = [] BLOCK_M_OPTIONS = [128, 64, 32, 16] BLOCK_N_OPTIONS = [128, 64, 32, 16] NUM_WARPS_OPTIONS = [2, 4, 8] @@ -441,12 +429,8 @@ def _attn_fwd_mask( # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n - # TODO: This can be optimized to only be true for the padded block. - # If this is the last block / iteration, we want to - # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. - # last step might get wasted but that is okay. check if this masking works For - # that case. + # If this is the last block / iteration, we want to mask if the + # sequence length is not a multiple of block size. if (n_extra_tokens != 0) and (start_n + BLOCK_N == block_max): boundary_m = tl.full([BLOCK_M], seqlen_k, dtype=tl.int32) size_n = start_n + offs_n[None, :] @@ -942,13 +926,9 @@ def compute_block_masking( # – the back side can require several masked blocks: # • intersection of the causal diagonal with K-grid # (at most ⌈BLOCK_M / BLOCK_N⌉ blocks) - # • plus one extra block if this Q-block stops in the - # middle of a K-block or the last K-block is padded + # • plus one for partial K blocks at the causal boundary # ------------------------------------------------------------ - padded_last_k = n_extra_tokens != 0 - is_modulo_mn = (not padded_last_k) & (seqlen_q % BLOCK_M == 0) - - n_back_masked_blocks = BLOCK_M // BLOCK_N + tl.where(is_modulo_mn, 0, 1) + n_back_masked_blocks = BLOCK_M // BLOCK_N + 1 n_back_masked_blocks = tl.minimum(n_back_masked_blocks, n_visible_k_blocks) n_front_skip_blocks = 0 # causal never skips the left side From 38288d11dae3ce55b240b481ade5a309d413f8b1 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Wed, 21 Jan 2026 21:29:14 -0600 Subject: [PATCH 26/27] fix block_size_n bug --- flash_attn/flash_attn_triton_amd/bwd.py | 11 +- .../flash_attn_triton_amd/fwd_decode.py | 6 +- .../flash_attn_triton_amd/fwd_prefill.py | 658 +++++------------- flash_attn/flash_attn_triton_amd/utils.py | 60 +- tests/test_flash_attn_triton_amd.py | 18 +- 5 files changed, 241 insertions(+), 512 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 55914b899ae..589a3ffeebf 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -47,7 +47,7 @@ def get_bwd_configs(autotune: bool): arch = get_arch() # configs for the kernels - if arch == "gfx942": + if arch.name == "gfx942": if get_cu_count() < 304: preprocess_configs = [ triton.Config( @@ -321,14 +321,7 @@ def get_bwd_configs(autotune: bool): num_warps=4, ), ] - elif arch in ( - "gfx1030", - "gfx1100", - "gfx1101", - "gfx1102", - "gfx1200", - "gfx1201", - ): # RDNA architectures + elif arch.is_rdna: preprocess_configs = [ triton.Config( {"PRE_BLOCK": 32}, num_stages=1, num_warps=4 diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index e901da29ad6..f31e28ffb89 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -41,8 +41,7 @@ def get_fwd_decode_configs(autotune: bool): if not autotune: arch = get_arch() - if arch in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"): - # RDNA architectures + if arch.is_rdna: return ( [ triton.Config( @@ -53,8 +52,7 @@ def get_fwd_decode_configs(autotune: bool): ], [triton.Config({}, num_stages=1, num_warps=4)], ) - elif arch in ("gfx940", "gfx941", "gfx942", "gfx950"): - # CDNA architectures + elif arch.is_cdna: return ( [ triton.Config( diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 1729b5e5d68..ff2b3d71131 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -12,7 +12,6 @@ get_arch, get_cu_count, is_fp8, - is_rdna, apply_rotary, get_recommended_fp8_dtype, ) @@ -33,15 +32,19 @@ def get_fwd_prefill_configs(autotune: bool): - # get best config for the architecture + # Get best config for the architecture. + # NOTE: Tests expect specific BLOCK_N sizes for attention score renormalization: + # - CDNA: BLOCK_N=64 + # - RDNA: BLOCK_N=32 + # See _get_block_size_n_triton() in test_flash_attn_triton_amd.py if not autotune: arch = get_arch() - if arch == "gfx950": + if arch.name == "gfx950": return [ triton.Config( { "BLOCK_M": 128, - "BLOCK_N": 128, + "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False, }, @@ -49,10 +52,9 @@ def get_fwd_prefill_configs(autotune: bool): num_warps=4, ) ] - elif arch == "gfx942": + elif arch.name == "gfx942": if get_cu_count() < 304: return [ - # best fp8 config triton.Config( { "BLOCK_M": 128, @@ -63,17 +65,6 @@ def get_fwd_prefill_configs(autotune: bool): num_stages=1, num_warps=4, ), - # best f16 config - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 32, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=2, - num_warps=4, - ), ] else: return [ @@ -88,14 +79,7 @@ def get_fwd_prefill_configs(autotune: bool): num_warps=4, ) ] - elif arch in ( - "gfx1030", - "gfx1100", - "gfx1101", - "gfx1102", - "gfx1200", - "gfx1201", - ): # RDNA architectures + elif arch.is_rdna: return [ triton.Config( {"BLOCK_M": 32, "BLOCK_N": 32, "PRE_LOAD_V": False}, @@ -151,195 +135,7 @@ def get_fwd_prefill_configs(autotune: bool): @triton.jit -def _attn_fwd_no_mask( - acc, - l_i, - m_i, - q, - k_base_ptrs, - v_base_ptrs, - bias_base_ptrs, - stride_kn, - stride_vk, - stride_bn, - stride_sn, - stride_sm, - start_m, - seqlen_k, - seqlen_q, - dropout_p, - philox_seed, - philox_offset_base, - sd_mask, - stride_sz, - stride_sh, - off_z, - off_h_q, - offs_m, - offs_n, - offs_d_qk, - offs_d_v, - block_min, - block_max, - alibi_slope, - q_descale, - k_descale, - v_descale, - IS_FP8: tl.constexpr, - FP8_MAX: tl.constexpr, - FP8_P_DESCALE: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL_QK: tl.constexpr, - BLOCK_DMODEL_V: tl.constexpr, - BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - PADDED_HEAD_QK: tl.constexpr, - PADDED_HEAD_V: tl.constexpr, - ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, - ACTUAL_BLOCK_DMODEL_V: tl.constexpr, - SM_SCALE: tl.constexpr, - USE_ALIBI: tl.constexpr, - USE_EXP2: tl.constexpr, - RETURN_SCORES: tl.constexpr, - ACCUMULATOR_TYPE, -): - if USE_EXP2: - RCP_LN2: tl.constexpr = 1.4426950408889634 - - # loop over k, v, and update accumulator - for start_n in range(block_min, block_max, BLOCK_N): - # get ptrs - k_ptrs = k_base_ptrs + start_n * stride_kn - v_ptrs = v_base_ptrs + start_n * stride_vk - - kv_offs_n = start_n + tl.arange(0, BLOCK_N) - # Load K - if PADDED_HEAD_QK: - k_mask = offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK - k = tl.load(k_ptrs, mask=k_mask, other=0.0) - else: - k = tl.load(k_ptrs) - - # Optionally preload V - if PRE_LOAD_V: - if PADDED_HEAD_V: - v_mask = offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V - v = tl.load(v_ptrs, mask=v_mask, other=0.0) - else: - v = tl.load(v_ptrs) - - # setup qk accumlator - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) - - # -- compute qk ---- - if IS_FP8: - qk += tl.dot(q, k) * q_descale * k_descale - else: - qk += tl.dot(q, k) - qk_scaled = qk * SM_SCALE - - if USE_ALIBI: - # compute the global position of each token within the sequence - q_offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - alibi_block = compute_alibi_block( - alibi_slope, seqlen_q, seqlen_k, q_offs_m, kv_offs_n - ) - qk_scaled += alibi_block - - # compute qk mask - qk_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) - - # compute bias - if bias_base_ptrs is not None: - bias_ptrs = bias_base_ptrs + start_n * stride_bn - bias = tl.load(bias_ptrs, mask=qk_mask, other=0.0) - qk_scaled += bias - - # get max scores so far - m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) - - # scale and subtract max - q_shifted = tl.where( - m_ij[:, None] == float("-inf"), float("-inf"), qk_scaled - m_ij[:, None] - ) - - # Compute scaled QK and softmax probabilities - if USE_EXP2: - p = tl.math.exp2(q_shifted * RCP_LN2) - else: - p = tl.math.exp(q_shifted) - - # CAVEAT: Must update l_ij before applying dropout - l_ij = tl.sum(p, 1) - if ENABLE_DROPOUT: - # Compute pointers for this block - philox_base = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh - philox_ptrs = philox_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn - - # compute dropout mask - rng_output = tl.rand(philox_seed, philox_ptrs) - dropout_mask = rng_output > dropout_p - - # return scores with negative values for dropped vals (only if RETURN_SCORES is True) - if RETURN_SCORES: - sd_mask_value = tl.where(dropout_mask, p, -p) - sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh - sd_mask_ptrs = sd_mask_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn - - # Compute mask for sd_mask storage - sd_store_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) - tl.store(sd_mask_ptrs, sd_mask_value, mask=sd_store_mask) - - # apply dropout mask in place - p = tl.where(dropout_mask, p, 0.0) - elif RETURN_SCORES: - # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh - sd_mask_ptrs = sd_mask_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn - - # Compute mask for sd_mask storage - sd_store_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) - tl.store(sd_mask_ptrs, p, mask=sd_store_mask) - - # -- update output accumulator -- - # alpha is an adjustment factor for acc and li as we loop and find new maxes - # store the diff in maxes to adjust acc and li as we discover new maxes - m_diff = tl.where(m_ij == float("-inf"), float("-inf"), m_i - m_ij) - if USE_EXP2: - alpha = tl.math.exp2(m_diff * RCP_LN2) - else: - alpha = tl.math.exp(m_diff) - acc = acc * alpha[:, None] - if not PRE_LOAD_V: - if PADDED_HEAD_V: - v_mask = offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V - v = tl.load(v_ptrs, mask=v_mask, other=0.0) - else: - v = tl.load(v_ptrs) - - # -- update m_i and l_i - l_i = l_i * alpha + l_ij - m_i = m_ij - - if IS_FP8: - if FP8_P_DESCALE: - scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) - acc += ( - tl.dot((p * scale_p).to(v.type.element_ty), v) - * descale_p - * v_descale - ) - else: - acc += tl.dot(p.to(v.type.element_ty), v) * v_descale - else: - acc += tl.dot(p.to(v.type.element_ty), v) - - return acc, l_i, m_i - - -@triton.jit -def _attn_fwd_mask( +def _attn_fwd_inner( acc, l_i, m_i, @@ -377,6 +173,7 @@ def _attn_fwd_mask( IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_P_DESCALE: tl.constexpr, + APPLY_MASK: tl.constexpr, # True for masked blocks, False for full blocks IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL_QK: tl.constexpr, @@ -397,10 +194,17 @@ def _attn_fwd_mask( WINDOW_SIZE_RIGHT: tl.constexpr, ACCUMULATOR_TYPE, ): + """ + Unified attention forward inner loop. + + APPLY_MASK controls whether causal/window masking is applied: + - False: Fast path for full blocks (no masking overhead) + - True: Masked path with causal/window masking support + """ if USE_EXP2: RCP_LN2: tl.constexpr = 1.4426950408889634 - # seqlen diff + # seqlen diff (only used when APPLY_MASK=True) seqlen_delta_qk = seqlen_k - seqlen_q # loop over k, v, and update accumulator @@ -409,33 +213,44 @@ def _attn_fwd_mask( k_ptrs = k_base_ptrs + start_n * stride_kn v_ptrs = v_base_ptrs + start_n * stride_vk - # For padded blocks, we will overrun the tensor size if - # we load all BLOCK_N. For others, the blocks are all within range. kv_offs_n = start_n + tl.arange(0, BLOCK_N) - k_mask = kv_offs_n[None, :] < seqlen_k - v_mask = kv_offs_n[:, None] < seqlen_k - if PADDED_HEAD_QK: - k_mask = k_mask & (offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK) - if PADDED_HEAD_V: - v_mask = v_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) - - # load k and if preload_v then v - k = tl.load(k_ptrs, mask=k_mask, other=0.0) - if PRE_LOAD_V: - v = tl.load(v_ptrs, mask=v_mask, other=0.0) + + # Load K - different masking for APPLY_MASK vs non-masked + if APPLY_MASK: + # For masked blocks, check seqlen bounds + k_mask = kv_offs_n[None, :] < seqlen_k + v_mask = kv_offs_n[:, None] < seqlen_k + if PADDED_HEAD_QK: + k_mask = k_mask & (offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK) + if PADDED_HEAD_V: + v_mask = v_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) + k = tl.load(k_ptrs, mask=k_mask, other=0.0) + if PRE_LOAD_V: + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + else: + # For full blocks, only check head dimension padding + if PADDED_HEAD_QK: + k_mask = offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK + k = tl.load(k_ptrs, mask=k_mask, other=0.0) + else: + k = tl.load(k_ptrs) + if PRE_LOAD_V: + if PADDED_HEAD_V: + v_mask = offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + else: + v = tl.load(v_ptrs) - # setup qk accumlator + # setup qk accumulator qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) - # We start from end of seqlen_k so only the first iteration would need - # to be checked for padding if it is not a multiple of block_n - # If this is the last block / iteration, we want to mask if the - # sequence length is not a multiple of block size. - if (n_extra_tokens != 0) and (start_n + BLOCK_N == block_max): - boundary_m = tl.full([BLOCK_M], seqlen_k, dtype=tl.int32) - size_n = start_n + offs_n[None, :] - mask = size_n < boundary_m[:, None] - qk = tl.where(mask, qk, float("-inf")) + # Apply extra token masking for partial blocks (only when APPLY_MASK=True) + if APPLY_MASK: + if (n_extra_tokens != 0) and (start_n + BLOCK_N == block_max): + boundary_m = tl.full([BLOCK_M], seqlen_k, dtype=tl.int32) + size_n = start_n + offs_n[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) # -- compute qk ---- if IS_FP8: @@ -452,101 +267,62 @@ def _attn_fwd_mask( ) qk_scaled += alibi_block - if USE_SLIDING_WINDOW: - if IS_CAUSAL: - # ========== CAUSAL SLIDING WINDOW MASKING ========== - # For causal sliding window, we need to apply both constraints: - # 1. Causal: col_idx <= row_idx + (seqlen_k - seqlen_q) - # 2. Sliding window: row_idx - window_left <= col_idx <= row_idx + window_right - - # Get positions - row_idx = offs_m # Query positions - col_idx = kv_offs_n # Key positions - - # Expand for broadcasting - row_idx_expanded = row_idx[:, None] # [BLOCK_M, 1] - col_idx_expanded = col_idx[None, :] # [1, BLOCK_N] + # Apply causal/sliding window masking (only when APPLY_MASK=True) + if APPLY_MASK: + if USE_SLIDING_WINDOW: + if IS_CAUSAL: + # ========== CAUSAL SLIDING WINDOW MASKING ========== + row_idx = offs_m + col_idx = kv_offs_n + row_idx_expanded = row_idx[:, None] + col_idx_expanded = col_idx[None, :] - # Apply causal constraint: can only attend to positions before or at the diagonal - causal_offset = seqlen_k - seqlen_q - causal_mask = col_idx_expanded > (row_idx_expanded + causal_offset) + causal_offset = seqlen_k - seqlen_q + causal_mask = col_idx_expanded > (row_idx_expanded + causal_offset) - # Apply sliding window constraint - if WINDOW_SIZE_LEFT < 0: - # Only right window constraint - window_mask = col_idx_expanded > ( - row_idx_expanded + causal_offset + WINDOW_SIZE_RIGHT - ) + if WINDOW_SIZE_LEFT < 0: + window_mask = col_idx_expanded > ( + row_idx_expanded + causal_offset + WINDOW_SIZE_RIGHT + ) + else: + left_bound = row_idx_expanded + causal_offset - WINDOW_SIZE_LEFT + right_bound = row_idx_expanded + causal_offset + WINDOW_SIZE_RIGHT + window_mask = (col_idx_expanded < left_bound) | ( + col_idx_expanded > right_bound + ) + + mask = causal_mask | window_mask + qk_scaled = tl.where(mask, float("-inf"), qk_scaled) else: - # Both left and right window constraints - # Adjust window bounds by causal offset - left_bound = row_idx_expanded + causal_offset - WINDOW_SIZE_LEFT - right_bound = row_idx_expanded + causal_offset + WINDOW_SIZE_RIGHT - - # Can't attend to positions outside the window - window_mask = (col_idx_expanded < left_bound) | ( - col_idx_expanded > right_bound - ) - - # Final mask is the union of both constraints (True = cannot attend) - mask = causal_mask | window_mask + # ========== NON-CAUSAL SLIDING WINDOW MASKING ========== + row_idx = offs_m + col_idx = kv_offs_n + sk = seqlen_k + sq = seqlen_q + row_idx_expanded = row_idx[:, None] + col_idx_expanded = col_idx[None, :] - # Apply mask - qk_scaled = tl.where(mask, float("-inf"), qk_scaled) + if WINDOW_SIZE_LEFT < 0: + mask = col_idx_expanded > ( + row_idx_expanded + sk - sq + WINDOW_SIZE_RIGHT + ) + else: + sk_full = tl.full((1, BLOCK_N), sk, dtype=tl.int32) + right_bound_val = row_idx_expanded + sk - sq + WINDOW_SIZE_RIGHT + right_bound = tl.minimum(right_bound_val, sk_full) + left_bound = row_idx_expanded + sk - sq - WINDOW_SIZE_LEFT + mask = (col_idx_expanded > right_bound) | ( + col_idx_expanded < left_bound + ) + + qk_scaled = tl.where(mask, float("-inf"), qk_scaled) else: - # ========== NON-CAUSAL SLIDING WINDOW MASKING ========== - # Exactly matching reference construct_local_mask: - # row_idx = query positions, col_idx = key positions - # sk = seqlen_k, sq = seqlen_q - - # Get positions - row_idx = offs_m # Query positions - col_idx = kv_offs_n # Key positions - - # sk and sq from reference (no padding masks in this test) - sk = seqlen_k - sq = seqlen_q - - # Expand for broadcasting - row_idx_expanded = row_idx[:, None] # [BLOCK_M, 1] - col_idx_expanded = col_idx[None, :] # [1, BLOCK_N] - - # Reference logic for mask computation - if WINDOW_SIZE_LEFT < 0: - # Reference: return col_idx > row_idx + sk - sq + window_size[1] - mask = col_idx_expanded > ( - row_idx_expanded + sk - sq + WINDOW_SIZE_RIGHT - ) - else: - # Reference: - # sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk - # return torch.logical_or( - # col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), - # col_idx < row_idx + sk - sq - window_size[0], - # ) - # Create sk tensor with proper shape for broadcasting - # sk represents the key sequence length, which should be compared per column - sk_full = tl.full((1, BLOCK_N), sk, dtype=tl.int32) - - # Compute boundaries - right_bound_val = row_idx_expanded + sk - sq + WINDOW_SIZE_RIGHT - right_bound = tl.minimum(right_bound_val, sk_full) - left_bound = row_idx_expanded + sk - sq - WINDOW_SIZE_LEFT - - # Mask where True = cannot attend (matching reference) - mask = (col_idx_expanded > right_bound) | ( - col_idx_expanded < left_bound - ) - - # Apply mask (set to -inf where mask is True) - qk_scaled = tl.where(mask, float("-inf"), qk_scaled) - else: - if IS_CAUSAL: - causal_boundary = start_n + offs_n - seqlen_delta_qk - causal_mask = offs_m[:, None] >= causal_boundary[None, :] - qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n - seqlen_delta_qk + causal_mask = offs_m[:, None] >= causal_boundary[None, :] + qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) - # compute qk mask + # compute qk mask for bounds checking qk_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) # compute bias @@ -559,17 +335,10 @@ def _attn_fwd_mask( m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) # scale and subtract max - # IMPORTANT: Handle the case where all values are -inf - # When m_ij = -inf and qk_scaled = -inf, subtraction gives NaN - # We need to handle this explicitly - if USE_SLIDING_WINDOW: - # Check if this block has any valid values (m_ij != -inf) - # For rows where everything is -inf, set q_shifted to -inf (not NaN) - q_shifted = tl.where( - m_ij[:, None] == float("-inf"), float("-inf"), qk_scaled - m_ij[:, None] - ) - else: - q_shifted = qk_scaled - m_ij[:, None] + # Handle the case where all values are -inf + q_shifted = tl.where( + m_ij[:, None] == float("-inf"), float("-inf"), qk_scaled - m_ij[:, None] + ) # Compute scaled QK and softmax probabilities if USE_EXP2: @@ -594,23 +363,16 @@ def _attn_fwd_mask( sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh sd_mask_ptrs = sd_mask_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn - # Compute mask for sd_mask storage - include bounds check sd_store_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) - # Add causal mask if applicable to prevent writing to invalid positions - if IS_CAUSAL: - seqlen_delta_qk = seqlen_k - seqlen_q + if APPLY_MASK and IS_CAUSAL: causal_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk) sd_store_mask = sd_store_mask & causal_constraint - # Add sliding window mask if applicable - if USE_SLIDING_WINDOW: - seqlen_delta_qk = seqlen_k - seqlen_q + if APPLY_MASK and USE_SLIDING_WINDOW: if WINDOW_SIZE_LEFT < 0: - # Only right window constraint window_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT) else: - # Both left and right window constraints left_bound = offs_m[:, None] + seqlen_delta_qk - WINDOW_SIZE_LEFT right_bound = offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT window_constraint = (kv_offs_n[None, :] >= left_bound) & (kv_offs_n[None, :] <= right_bound) @@ -621,27 +383,19 @@ def _attn_fwd_mask( # apply dropout mask in place p = tl.where(dropout_mask, p, 0.0) elif RETURN_SCORES: - # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh sd_mask_ptrs = sd_mask_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn - # Compute mask for sd_mask storage - include bounds check sd_store_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) - # Add causal mask if applicable - if IS_CAUSAL: - seqlen_delta_qk = seqlen_k - seqlen_q + if APPLY_MASK and IS_CAUSAL: causal_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk) sd_store_mask = sd_store_mask & causal_constraint - # Add sliding window mask if applicable - if USE_SLIDING_WINDOW: - seqlen_delta_qk = seqlen_k - seqlen_q + if APPLY_MASK and USE_SLIDING_WINDOW: if WINDOW_SIZE_LEFT < 0: - # Only right window constraint window_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT) else: - # Both left and right window constraints left_bound = offs_m[:, None] + seqlen_delta_qk - WINDOW_SIZE_LEFT right_bound = offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT window_constraint = (kv_offs_n[None, :] >= left_bound) & (kv_offs_n[None, :] <= right_bound) @@ -650,16 +404,26 @@ def _attn_fwd_mask( tl.store(sd_mask_ptrs, p, mask=sd_store_mask) # -- update output accumulator -- - # alpha is an adjustment factor for acc and li as we loop and find new maxes - # store the diff in maxes to adjust acc and li as we discover new maxes m_diff = tl.where(m_ij == float("-inf"), float("-inf"), m_i - m_ij) if USE_EXP2: alpha = tl.math.exp2(m_diff * RCP_LN2) else: alpha = tl.math.exp(m_diff) acc = acc * alpha[:, None] + + # Load V if not preloaded if not PRE_LOAD_V: - v = tl.load(v_ptrs, mask=v_mask, other=0.0) + if APPLY_MASK: + v_mask = kv_offs_n[:, None] < seqlen_k + if PADDED_HEAD_V: + v_mask = v_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + else: + if PADDED_HEAD_V: + v_mask = offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + else: + v = tl.load(v_ptrs) # -- update m_i and l_i l_i = l_i * alpha + l_ij @@ -1050,6 +814,7 @@ def attn_fwd( FP8_MAX: tl.constexpr, FP8_P_DESCALE: tl.constexpr, USE_SEQUSED: tl.constexpr, + FORCE_MASKING: tl.constexpr, ): # set params ACCUMULATOR_TYPE = tl.float32 @@ -1244,7 +1009,7 @@ def attn_fwd( block_min = n_front_skip_blocks * BLOCK_N block_max = (n_front_skip_blocks + n_front_masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_mask( + acc, l_i, m_i = _attn_fwd_inner( acc, l_i, m_i, @@ -1282,18 +1047,19 @@ def attn_fwd( IS_FP8, FP8_MAX, FP8_P_DESCALE, - IS_CAUSAL, - BLOCK_M, - BLOCK_DMODEL_QK, - BLOCK_DMODEL_V, - BLOCK_N, - PRE_LOAD_V, - ENABLE_DROPOUT, - PADDED_HEAD_QK, - PADDED_HEAD_V, - ACTUAL_BLOCK_DMODEL_QK, - ACTUAL_BLOCK_DMODEL_V, - SM_SCALE, + APPLY_MASK=True, # Masked blocks + IS_CAUSAL=IS_CAUSAL, + BLOCK_M=BLOCK_M, + BLOCK_DMODEL_QK=BLOCK_DMODEL_QK, + BLOCK_DMODEL_V=BLOCK_DMODEL_V, + BLOCK_N=BLOCK_N, + PRE_LOAD_V=PRE_LOAD_V, + ENABLE_DROPOUT=ENABLE_DROPOUT, + PADDED_HEAD_QK=PADDED_HEAD_QK, + PADDED_HEAD_V=PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK=ACTUAL_BLOCK_DMODEL_QK, + ACTUAL_BLOCK_DMODEL_V=ACTUAL_BLOCK_DMODEL_V, + SM_SCALE=SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, @@ -1310,7 +1076,7 @@ def attn_fwd( n_front_skip_blocks + n_front_masked_blocks + n_full_blocks ) * BLOCK_N - acc, l_i, m_i = _attn_fwd_no_mask( + acc, l_i, m_i = _attn_fwd_inner( acc, l_i, m_i, @@ -1340,6 +1106,7 @@ def attn_fwd( offs_d_v, block_min, # Start of range: 0 block_max, # End of range: n_full_blocks * BLOCK_N + 0, # n_extra_tokens (not used for full blocks) alibi_slope, q_descale, k_descale, @@ -1347,20 +1114,25 @@ def attn_fwd( IS_FP8, FP8_MAX, FP8_P_DESCALE, - BLOCK_M, - BLOCK_DMODEL_QK, - BLOCK_DMODEL_V, - BLOCK_N, - PRE_LOAD_V, - ENABLE_DROPOUT, - PADDED_HEAD_QK, - PADDED_HEAD_V, - ACTUAL_BLOCK_DMODEL_QK, - ACTUAL_BLOCK_DMODEL_V, - SM_SCALE, + APPLY_MASK=FORCE_MASKING, + IS_CAUSAL=IS_CAUSAL, + BLOCK_M=BLOCK_M, + BLOCK_DMODEL_QK=BLOCK_DMODEL_QK, + BLOCK_DMODEL_V=BLOCK_DMODEL_V, + BLOCK_N=BLOCK_N, + PRE_LOAD_V=PRE_LOAD_V, + ENABLE_DROPOUT=ENABLE_DROPOUT, + PADDED_HEAD_QK=PADDED_HEAD_QK, + PADDED_HEAD_V=PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK=ACTUAL_BLOCK_DMODEL_QK, + ACTUAL_BLOCK_DMODEL_V=ACTUAL_BLOCK_DMODEL_V, + SM_SCALE=SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, + USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, + WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE, ) @@ -1376,7 +1148,7 @@ def attn_fwd( + n_back_masked_blocks ) * BLOCK_N - acc, l_i, m_i = _attn_fwd_mask( + acc, l_i, m_i = _attn_fwd_inner( acc, l_i, m_i, @@ -1414,18 +1186,19 @@ def attn_fwd( IS_FP8, FP8_MAX, FP8_P_DESCALE, - IS_CAUSAL, # Use actual causal flag - BLOCK_M, - BLOCK_DMODEL_QK, - BLOCK_DMODEL_V, - BLOCK_N, - PRE_LOAD_V, - ENABLE_DROPOUT, - PADDED_HEAD_QK, - PADDED_HEAD_V, - ACTUAL_BLOCK_DMODEL_QK, - ACTUAL_BLOCK_DMODEL_V, - SM_SCALE, + APPLY_MASK=True, # Masked blocks + IS_CAUSAL=IS_CAUSAL, # Use actual causal flag + BLOCK_M=BLOCK_M, + BLOCK_DMODEL_QK=BLOCK_DMODEL_QK, + BLOCK_DMODEL_V=BLOCK_DMODEL_V, + BLOCK_N=BLOCK_N, + PRE_LOAD_V=PRE_LOAD_V, + ENABLE_DROPOUT=ENABLE_DROPOUT, + PADDED_HEAD_QK=PADDED_HEAD_QK, + PADDED_HEAD_V=PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK=ACTUAL_BLOCK_DMODEL_QK, + ACTUAL_BLOCK_DMODEL_V=ACTUAL_BLOCK_DMODEL_V, + SM_SCALE=SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, @@ -1438,19 +1211,15 @@ def attn_fwd( # ============================================================ # EPILOGUE # ============================================================ - # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. - # Instead of directly computing 1/l_i which can be inf, - # we check for the invalid case first - if USE_SLIDING_WINDOW: - # For rows where m_i is still -inf, no keys were valid - # Set l_i to 1.0 to avoid division by zero (acc is already 0) - invalid_mask = m_i == float("-inf") - l_i_safe = tl.where(invalid_mask, 1.0, l_i) - l_recip = 1 / l_i_safe[:, None] - else: - invalid_mask = None - l_recip = 1 / l_i[:, None] + # Handle invalid rows: rows with no valid keys to attend to. + # This occurs with sliding window or causal attention (when seqlen_q > seqlen_k). + # For invalid rows: m_i = -inf, l_i = 0, acc = 0. + # We set l_i = 1.0 to avoid division by zero and ensure LSE = -inf. + invalid_mask = m_i == float("-inf") + l_i_safe = tl.where(invalid_mask, 1.0, l_i) + l_recip = 1 / l_i_safe[:, None] acc = acc * l_recip + if ENABLE_DROPOUT: dropout_scale = 1 / (1 - dropout_p) acc = acc * dropout_scale @@ -1459,68 +1228,12 @@ def attn_fwd( if USE_EXP2: RCP_LN2: tl.constexpr = 1.4426950408889634 LN2: tl.constexpr = 0.6931471824645996 - # compute log-sum-exp in base 2 units - mi_base2 = m_i * RCP_LN2 - # For invalid rows, log(l_i) would be -inf, but we want LSE to be -inf - # So we handle this case explicitly - if USE_SLIDING_WINDOW: - log_l_i = tl.where(invalid_mask, 0.0, tl.math.log2(l_i)) - softmax_lse = mi_base2 + log_l_i - # Ensure invalid rows have LSE = -inf - softmax_lse = tl.where(invalid_mask, float("-inf"), softmax_lse) - else: - softmax_lse = mi_base2 + tl.math.log2(l_i) - # convert back to natural units - softmax_lse *= LN2 - else: - if USE_SLIDING_WINDOW: - log_l_i = tl.where(invalid_mask, 0.0, tl.math.log(l_i)) - softmax_lse = m_i + log_l_i - softmax_lse = tl.where(invalid_mask, float("-inf"), softmax_lse) - else: - softmax_lse = m_i + tl.math.log(l_i) - - # handle masking edge cases - if USE_SLIDING_WINDOW: - if IS_CAUSAL: - pass - else: - pass + softmax_lse = (m_i * RCP_LN2 + tl.math.log2(l_i)) * LN2 else: - if IS_CAUSAL: - # When seqlen_q > seqlen_k, some rows are completely above the causal diagonal - # These rows have all -inf attention scores, resulting in NaN after softmax - # e.g. - # Q length: 6, K length: 4 - # Causal mask (X = can attend, . = cannot): - # K0 K1 K2 K3 - # Q0 . . . . <- All masked, would give NaN - # Q1 . . . . <- All masked, would give NaN - # Q2 X . . . <- First valid row - # Q3 X X . . - # Q4 X X X . - # Q5 X X X X - causal_start_idx = seqlen_q - seqlen_k - start_m_idx = start_m * BLOCK_M - - # Create mask for rows that need zeroing - row_indices = start_m_idx + tl.arange(0, BLOCK_M) - causal_mask = row_indices < causal_start_idx - - # Zero out both acc and LSE for these rows - if causal_start_idx > start_m_idx: - end_m_idx = (start_m + 1) * BLOCK_M - if causal_start_idx < end_m_idx: - # This block contains the boundary - need to mask acc - out_mask_boundary = tl.full( - (BLOCK_DMODEL_V,), causal_start_idx, dtype=tl.int32 - ) - out_ptrs_mask = row_indices[:, None] >= out_mask_boundary[None, :] - z = 0.0 - acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + softmax_lse = m_i + tl.math.log(l_i) - # Zero out LSE for rows above diagonal - softmax_lse = tl.where(causal_mask, 0.0, softmax_lse) + # Ensure invalid rows have LSE = -inf + softmax_lse = tl.where(invalid_mask, float("-inf"), softmax_lse) # write back LSE(Log Sum Exponents), the log of the normalization constant l_offset = ( @@ -1946,6 +1659,10 @@ def attention_forward_prefill_triton_impl( else: stride_bz, stride_bh, stride_bm, stride_bn = (0, 0, 0, 0) + # Detect if we need to force masking for all blocks (required on some architectures) + arch = get_arch() + force_masking = arch.is_rdna + # launch kernel grid = lambda META: (batch, nheads_q, triton.cdiv(max_seqlens_q, META["BLOCK_M"])) attn_fwd[grid]( @@ -2022,4 +1739,5 @@ def attention_forward_prefill_triton_impl( FP8_MAX=FP8_MAX, FP8_P_DESCALE=False, USE_SEQUSED=(seqused_q is not None or seqused_k is not None), + FORCE_MASKING=force_masking, ) diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index d4cfda18a23..7743fcf4e47 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -7,8 +7,33 @@ import triton import triton.language as tl import numpy as np +from dataclasses import dataclass from typing import Literal, Optional, Union, Tuple + +# ------------------------------- +# GPU Architecture +# ------------------------------- +ArchFamily = Literal["cdna", "rdna"] + +CDNA_ARCHS = frozenset({"gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx950"}) +RDNA_ARCHS = frozenset({"gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"}) + + +@dataclass(frozen=True) +class GpuArch: + """GPU architecture information.""" + name: str # e.g., "gfx942", "gfx1100" + family: Optional[ArchFamily] = None + + @property + def is_cdna(self) -> bool: + return self.family == "cdna" + + @property + def is_rdna(self) -> bool: + return self.family == "rdna" + # ------------------------------- # Gloabl Variables # ------------------------------- @@ -1487,8 +1512,15 @@ def is_hip(): @functools.cache -def get_arch(): - return triton.runtime.driver.active.get_current_target().arch +def get_arch() -> GpuArch: + """Get the current GPU architecture.""" + name = triton.runtime.driver.active.get_current_target().arch + if name in CDNA_ARCHS: + return GpuArch(name=name, family="cdna") + elif name in RDNA_ARCHS: + return GpuArch(name=name, family="rdna") + else: + return GpuArch(name=name) @functools.cache @@ -1496,27 +1528,3 @@ def get_cu_count(): return torch.cuda.get_device_properties( torch.cuda.current_device() ).multi_processor_count - - -@functools.cache -def is_cdna(): - return is_hip() and get_arch() in ( - "gfx908", - "gfx90a", - "gfx940", - "gfx941", - "gfx942", - "gfx950", - ) - - -@functools.cache -def is_rdna(): - return is_hip() and get_arch() in ( - "gfx1030", - "gfx1100", - "gfx1101", - "gfx1102", - "gfx1200", - "gfx1201", - ) diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py index b924890b115..ac1ca579d0f 100755 --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -16,7 +16,19 @@ from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import _get_block_size_n from flash_attn.layers.rotary import apply_rotary_emb -from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_hip, is_rdna +from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_hip, get_arch + + +def _get_block_size_n_triton(device, head_dim, is_dropout, is_causal): + """Get block size for Triton AMD kernel.""" + arch = get_arch() + if arch.is_rdna: + return 32 + elif arch.is_cdna: + return 64 + # Fall back to CUDA kernel block sizes + return _get_block_size_n(device, head_dim, is_dropout, is_causal) + MAX_HEADDIM_SM8x = 192 @@ -507,7 +519,7 @@ def normalize_flash_attn_S( scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias.to(dtype=scores.dtype) - block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal) + block_size_n = _get_block_size_n_triton(scores.device, head_dim, is_dropout, causal) scores_block = scores.split(block_size_n, dim=-1) lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) lse = torch.logsumexp(lse_block, dim=-1) @@ -1491,7 +1503,7 @@ def test_flash_attn_varlen_output( # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): if USE_TRITON_ROCM: - if is_rdna(): + if get_arch().is_rdna: if seqlen_q == 1 and seqlen_k == 239 and d == 256: pytest.skip("This config doesnot work on RDNA Devices.") if ( From 82006063cb81e26b10169412807e426a91e551b5 Mon Sep 17 00:00:00 2001 From: Michael Date: Thu, 22 Jan 2026 14:32:49 -0500 Subject: [PATCH 27/27] fix type check bug --- flash_attn/flash_attn_triton_amd/bwd.py | 26 +- flash_attn/flash_attn_triton_amd/common.py | 551 ++++++ .../flash_attn_triton_amd/fwd_decode.py | 132 +- .../flash_attn_triton_amd/fwd_prefill.py | 16 +- .../flash_attn_triton_amd/interface_v2.py | 63 +- .../flash_attn_triton_amd/interface_v3.py | 109 +- .../flash_attn_triton_amd/pyproject.toml | 48 + flash_attn/flash_attn_triton_amd/utils.py | 1526 ++--------------- 8 files changed, 897 insertions(+), 1574 deletions(-) create mode 100644 flash_attn/flash_attn_triton_amd/common.py create mode 100644 flash_attn/flash_attn_triton_amd/pyproject.toml diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 589a3ffeebf..87dc49fc9bc 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -1,14 +1,13 @@ import os import torch -import triton # type: ignore -import triton.language as tl # type: ignore +import triton +import triton.language as tl import warnings from typing import Literal, Optional +from .common import compute_fp8_scaling_factors from .utils import ( DEBUG, AUTOTUNE, - compute_fp8_scaling_factors, - get_cu_count, is_fp8, get_arch, ) @@ -48,7 +47,7 @@ def get_bwd_configs(autotune: bool): # configs for the kernels if arch.name == "gfx942": - if get_cu_count() < 304: + if arch.cu_count < 304: preprocess_configs = [ triton.Config( {"PRE_BLOCK": 64, "waves_per_eu": 1}, num_stages=1, num_warps=8 @@ -233,7 +232,7 @@ def get_bwd_configs(autotune: bool): num_warps=4, ), ] - elif arch == "gfx950": + elif arch.name == "gfx950": preprocess_configs = [ triton.Config( {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 @@ -3891,8 +3890,12 @@ def is_contiguous(x, name): return x.contiguous() -DEBUG_TRITON: bool = False -DEBUG_TRITON_DETAIL: bool = False +# Triton kernel debug flags derived from DEBUG level. +# Level 1: basic kernel debug prints (iteration info) +# Level 2: detailed kernel debug prints (tensor values) +# Requires TRITON_INTERPRET=1 to actually print inside kernels. +DEBUG_TRITON: bool = DEBUG >= 1 +DEBUG_TRITON_DETAIL: bool = DEBUG >= 2 def attention_backward_triton_impl( @@ -4131,6 +4134,11 @@ def attention_backward_triton_impl( # fp8 IS_FP8 = is_fp8([q, k, v]) if IS_FP8: + arch = get_arch() + if not arch.supports_fp8: + raise RuntimeError( + f"{arch.name} does not support FP8" + ) FP8_MAX = torch.finfo(q.dtype).max warnings.warn( @@ -4233,7 +4241,7 @@ def attention_backward_triton_impl( IS_FP8=IS_FP8, ) - if False: + if DEBUG: print("delta:", delta, delta.shape) # dropout mask tensor for debugging. We dump the dropout mask created in diff --git a/flash_attn/flash_attn_triton_amd/common.py b/flash_attn/flash_attn_triton_amd/common.py new file mode 100644 index 00000000000..2f1a209383a --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/common.py @@ -0,0 +1,551 @@ +""" +Triton kernel helper functions shared across flash attention modules. + +This module contains Triton JIT-compiled helper functions that are used within +the main attention kernels (fwd_prefill, fwd_decode, bwd). These are kept +separate from utils.py to allow stricter type checking on pure Python utilities. +""" +from typing import Literal, Optional, Tuple, Union + +import torch +import triton +import triton.language as tl + +from .utils import DEBUG, get_shape_from_layout, get_stride_from_layout, is_fp8 + + +@triton.jit +def compute_alibi_block( + alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False +): + """ + Compute ALiBi (Attention with Linear Biases) block. + + When seqlen_k and seqlen_q are different, the diagonal sticks to the + bottom right of the attention matrix. + """ + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5 + # offs_m = [0, 1], offs_n = [0, 1, 2, 3, 4] + # Result: [[-3, -2, -1, 0, -1], [-4, -3, -2, -1, 0]] + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + + +@triton.jit +def compute_fp8_scaling_factors(x, fp8_max: tl.constexpr): + """Compute FP8 scaling and descaling factors for a block.""" + x_amax = tl.max(tl.abs(x)) + x_amax = tl.where(x_amax <= 1e-9, 1e-9, x_amax) + scale_x = fp8_max / x_amax + descale_x = x_amax / fp8_max + return scale_x, descale_x + + +@triton.jit +def _cast_varlen_to_fp8_kernel_2d( + X, + X_fp8, + Descale, + cu_seqlens, + H, + MAX_SEQLEN, + stride_batch, + stride_seq, + stride_head, + stride_dim, + stride_out_batch, + stride_out_seq, + stride_out_head, + stride_out_dim, + stride_desc_batch, + stride_desc_head, + FP8_CLAMP_VAL, + FP8_MAX, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + """Cast tensor to FP8 with per-(batch, head) scaling.""" + b_id = tl.program_id(0) + h_id = tl.program_id(1) + + # Get sequence bounds for this batch + if IS_VARLEN: + seq_start = tl.load(cu_seqlens + b_id) + seq_end = tl.load(cu_seqlens + b_id + 1) + seqlen = seq_end - seq_start + else: + seq_start = 0 + seqlen = MAX_SEQLEN + + # initialize max value tracker + x_max_val = 0.0 + + # STEP 1: Find max absolute value across the entire sequence + num_of_blocks = tl.cdiv(seqlen, BLOCK_SIZE) + for blk_idx in range(0, num_of_blocks): + offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_dim = tl.arange(0, HEAD_DIM) + + mask_seq = offs_seq[:, None] < seqlen + if ACTUAL_HEAD_DIM != HEAD_DIM: + mask_dim = offs_dim[None, :] < ACTUAL_HEAD_DIM + mask_seq = mask_seq & mask_dim + + adj_x = ( + b_id * stride_batch + + h_id * stride_head + + seq_start * stride_seq + + offs_seq[:, None] * stride_seq + + offs_dim[None, :] * stride_dim + ) + x_block = tl.load(X + adj_x, mask=mask_seq, other=0.0) + block_max = tl.max(tl.abs(x_block)) + x_max_val = tl.maximum(x_max_val, block_max) + + # clamp to avoid division by zero + x_max_val = tl.maximum(x_max_val, FP8_CLAMP_VAL) + + # compute scale and descale factors + scale = FP8_MAX / x_max_val + descale = x_max_val / FP8_MAX + + # store descale factor + desc_ptr = Descale + b_id * stride_desc_batch + h_id + tl.store(desc_ptr, descale) + + # STEP 2: Apply scaling and convert to FP8 + for blk_idx in range(0, num_of_blocks): + offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_dim = tl.arange(0, HEAD_DIM) + + mask_seq = offs_seq[:, None] < seqlen + if ACTUAL_HEAD_DIM != HEAD_DIM: + mask_dim = offs_dim[None, :] < ACTUAL_HEAD_DIM + mask_seq = mask_seq & mask_dim + + addr = ( + b_id * stride_batch + + h_id * stride_head + + seq_start * stride_seq + + offs_seq[:, None] * stride_seq + + offs_dim[None, :] * stride_dim + ) + x_block = tl.load(X + addr, mask=mask_seq, other=0.0) + x_fp8_block = (x_block * scale).to(X_fp8.type.element_ty) + + addr_out = ( + b_id * stride_out_batch + + h_id * stride_out_head + + seq_start * stride_out_seq + + offs_seq[:, None] * stride_out_seq + + offs_dim[None, :] * stride_out_dim + ) + tl.store(X_fp8 + addr_out, x_fp8_block, mask=mask_seq) + + +@triton.jit +def _rotary_kernel( + OUT, + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, + seqlen, + nheads, + seqlen_ro, + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + ROTARY_DIM: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_M: tl.constexpr, +): + """Apply rotary positional embeddings.""" + BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM) + ROTARY_DIM_HALF = ROTARY_DIM // 2 + pid_head = tl.program_id(axis=0) + pid_m = tl.program_id(axis=1) + pid_batch = tl.program_id(axis=2) + + if not IS_VARLEN: + X = X + pid_batch * stride_x_batch + OUT = OUT + pid_batch * stride_out_batch + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X = X + start_idx * stride_x_seqlen + OUT = OUT + start_idx * stride_out_seqlen + + if pid_m * BLOCK_M >= seqlen: + return + + rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + + rk_half = tl.arange(0, BLOCK_K // 2) + COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) + SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) + mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF) + cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32) + sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + + if not INTERLEAVED: + X = X + ( + rh[:, None, None] * stride_x_nheads + + rm[None, :, None] * stride_x_seqlen + + rk_half[None, None, :] * stride_x_headdim + ) + OUT = OUT + ( + rh[:, None, None] * stride_out_nheads + + rm[None, :, None] * stride_out_seqlen + + rk_half[None, None, :] * stride_out_headdim + ) + mask = ( + (rh[:, None, None] < nheads) + & (rm[None, :, None] < seqlen) + & (rk_half[None, None, :] < ROTARY_DIM_HALF) + ) + x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0).to( + tl.float32 + ) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + tl.store(OUT, o0, mask=mask) + tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask) + else: + rk = tl.arange(0, BLOCK_K) + X = X + ( + rh[:, None, None] * stride_x_nheads + + rm[None, :, None] * stride_x_seqlen + + rk[None, None, :] * stride_x_headdim + ) + OUT = OUT + ( + rh[:, None, None] * stride_out_nheads + + rm[None, :, None] * stride_out_seqlen + + rk[None, None, :] * stride_out_headdim + ) + mask = ( + (rh[:, None, None] < nheads) + & (rm[None, :, None] < seqlen) + & (rk[None, None, :] < ROTARY_DIM) + ) + x = tl.load(X, mask=mask, other=0.0).to(tl.float32) + x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2])) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K]) + tl.store(OUT, o, mask=mask) + + +# ------------------------------- +# Python wrappers for Triton kernels +# ------------------------------- + + +def cast_to_fp8( + x: torch.Tensor, + fp8_dtype: torch.dtype, + layout: Literal["bshd", "thd"], + clamp_val: float = 1e-9, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Cast tensor to FP8 with per-(batch, head) scaling factors.""" + if DEBUG > 0: + print() + print("cast_to_fp8") + print("x:", x, x.shape) + print("fp8_dtype:", fp8_dtype) + print("cu_seqlens:", cu_seqlens) + print("max_seqlen:", max_seqlen) + print("clamp_val:", clamp_val) + + assert x.dtype in { + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + } and is_fp8(fp8_dtype), f"Cannot cast {x.dtype} to {fp8_dtype}" + + batch, max_seqlen_final, num_heads, head_dim = get_shape_from_layout( + x, layout, cu_seqlens, max_seqlen + ) + is_varlen = layout == "thd" + fp8_max = torch.finfo(fp8_dtype).max + + padded_head_dim = 1 << (head_dim - 1).bit_length() + padded_head_dim = max(padded_head_dim, 32) + + x_fp8 = torch.zeros_like(x, dtype=fp8_dtype) + descale_factors = torch.zeros( + (batch, num_heads), device=x.device, dtype=torch.float32 + ) + BLOCK_SIZE = 128 + + stride_batch, stride_head, stride_seq, stride_dim = get_stride_from_layout(x, layout) + stride_out_batch, stride_out_head, stride_out_seq, stride_out_dim = get_stride_from_layout(x_fp8, layout) + stride_desc_batch, stride_desc_head = descale_factors.stride() + + grid = (batch, num_heads) + _cast_varlen_to_fp8_kernel_2d[grid]( + x, + x_fp8, + descale_factors, + cu_seqlens, + num_heads, + max_seqlen_final, + stride_batch, + stride_seq, + stride_head, + stride_dim, + stride_out_batch, + stride_out_seq, + stride_out_head, + stride_out_dim, + stride_desc_batch, + stride_desc_head, + clamp_val, + fp8_max, + BLOCK_SIZE=BLOCK_SIZE, + HEAD_DIM=padded_head_dim, + ACTUAL_HEAD_DIM=head_dim, + IS_VARLEN=is_varlen, + ) + + return x_fp8, descale_factors + + +def _apply_rotary_kernel( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + """Apply rotary positional embeddings using Triton kernel.""" + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is passed, max_seqlen must also be provided" + total_seqlen, nheads, headdim = x.shape + assert cu_seqlens is not None + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim_half = cos.shape + assert sin.shape == cos.shape + rotary_dim = 2 * rotary_dim_half + assert rotary_dim <= headdim + assert headdim <= 256 + assert seqlen_ro >= seqlen + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in (torch.int32, torch.int64) + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + out = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + out[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_M = 8 if rotary_dim <= 128 else 4 + grid = ( + triton.cdiv(nheads, 2), + triton.cdiv(seqlen, BLOCK_M), + batch, + ) + + with torch.cuda.device(x.device.index): + torch.library.wrap_triton(_rotary_kernel)[grid]( + out, + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, + nheads, + seqlen_ro, + out.stride(0) if not is_varlen else 0, + out.stride(-3), + out.stride(-2), + out.stride(-1), + x.stride(0) if not is_varlen else 0, + x.stride(-3), + x.stride(-2), + x.stride(-1), + rotary_dim, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + BLOCK_M=BLOCK_M, + BLOCK_H=2, + ) + return out + + +class _ApplyRotary(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool, + inplace: bool, + seqlen_offsets: Union[int, torch.Tensor], + cu_seqlens: Optional[torch.Tensor], + max_seqlen: Optional[int], + ) -> torch.Tensor: + out = _apply_rotary_kernel( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, + conjugate=False, + ) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x + + @staticmethod + def backward(ctx, do: torch.Tensor) -> tuple[torch.Tensor, None, None, None, None, None, None, None]: + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + dx = _apply_rotary_kernel( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + ) + return dx, None, None, None, None, None, None, None + + +def apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool = False, + inplace: bool = False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +) -> torch.Tensor: + """Apply rotary embeddings to tensor x. + + Args: + x: (B, S, H, D) if `cu_seqlens` is None else (total_S, H, D). + cos, sin: (S_rotary, rotary_dim/2) + interleaved: GPT-J style if True. + inplace: modify x in place. + seqlen_offsets: int or (B,) tensor of starting offsets per sequence. + cu_seqlens: (B+1,) tensor enabling varlen mode. + max_seqlen: required when `cu_seqlens` is provided. + """ + original_dtype = x.dtype + is_fp8_input = original_dtype == getattr(torch, "float8_e4m3fn", None) + if is_fp8_input: + target_dtype = ( + torch.bfloat16 + if cos.dtype == torch.bfloat16 or torch.cuda.is_bf16_supported() + else torch.float16 + ) + x_up = x.to(target_dtype) + cos_up = cos.to(target_dtype) if cos.dtype != target_dtype else cos + sin_up = sin.to(target_dtype) if sin.dtype != target_dtype else sin + out_up = _ApplyRotary.apply( + x_up, cos_up, sin_up, interleaved, False, seqlen_offsets, cu_seqlens, max_seqlen + ) + if inplace: + x.copy_(out_up.to(original_dtype)) + return x + return out_up.to(original_dtype) + else: + return _ApplyRotary.apply( + x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen + ) + + +def apply_rotary( + q: torch.Tensor, + k_new: Optional[torch.Tensor], + cos: torch.Tensor, + sin: torch.Tensor, + *, + causal: bool, + local: bool, + interleaved: bool = False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Apply rotary embeddings to q and optionally k_new. + + Policy: + - If causal OR local attention: apply rotary directly on (B, S, H, D). + - Else (non-causal global): flatten heads into sequence, apply, unflatten. + - k_new is always rotated directly when provided. + """ + assert q.ndim == 4, f"Expected q shape (B,S,H,D), got {q.shape}" + B, S, H, D = q.shape + use_flatten = (not causal) and (not local) + + if use_flatten: + q_flat = q.reshape(B, S * H, D).unsqueeze(1) + q_flat = apply_rotary_emb(q_flat, cos, sin, interleaved=interleaved, seqlen_offsets=seqlen_offsets) + q = q_flat.view(B, 1, S * H, D).reshape(B, S, H, D) + else: + q = apply_rotary_emb(q, cos, sin, interleaved=interleaved, seqlen_offsets=seqlen_offsets) + + if k_new is not None: + k_new = apply_rotary_emb(k_new, cos, sin, interleaved=interleaved, seqlen_offsets=seqlen_offsets) + return q, k_new diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index f31e28ffb89..4581b3f61d8 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -4,15 +4,15 @@ import triton import triton.language as tl from typing import Literal, Optional +from .common import apply_rotary from .utils import ( DEBUG, AUTOTUNE, get_arch, get_padded_headsize, - get_shape_and_strides_from_layout, - apply_rotary, + get_shape_from_layout, + get_stride_from_layout, is_fp8, - get_recommended_fp8_dtype, ) @@ -878,6 +878,10 @@ def attention_forward_decode_triton_impl( rotary_interleaved: bool = False, seqlens_rotary: Optional[torch.Tensor] = None, ): + # Validate layout at entry + if layout != "bshd": + raise ValueError(f"{layout} layout is not supported, only 'bshd' is supported") + # apply rotary embedding if rotary_cos is not None and rotary_sin is not None: # Prefer explicitly provided rotary sequence start offsets if given; fall back to cache_seqlens. @@ -968,16 +972,11 @@ def attention_forward_decode_triton_impl( use_cache_seqlens = cache_seqlens is not None use_sliding_window = window_size_left != -1 or window_size_right != -1 use_block_table = block_table is not None - SPLIT_K = None NUM_QUANT_GROUPS = 1 # get shapes and strides - (batch_size, seqlen_q, nheads_q, dim_q), ( - stride_qz, - stride_qh, - stride_qm, - stride_qd, - ) = get_shape_and_strides_from_layout(q, layout) + batch_size, seqlen_q, nheads_q, dim_q = get_shape_from_layout(q, layout) + stride_qz, stride_qh, stride_qm, stride_qd = get_stride_from_layout(q, layout) # Handle paged KV cache layout if use_block_table: @@ -989,6 +988,7 @@ def attention_forward_decode_triton_impl( seqlen_kc = int(cache_seqlens.max().item()) else: # Infer from block_table shape [batch_size, num_blocks_per_seq] + assert block_table is not None num_blocks_per_seq = block_table.shape[1] seqlen_kc = num_blocks_per_seq * block_size_k seqlen_vc = seqlen_kc @@ -1004,66 +1004,38 @@ def attention_forward_decode_triton_impl( stride_vc_h = v_cache.stride(2) stride_vc_d = v_cache.stride(3) else: - (_, seqlen_kc, nheads_kc, dim_kc), ( - stride_kc_z, - stride_kc_h, - stride_kc_n, - stride_kc_d, - ) = get_shape_and_strides_from_layout(k_cache, layout) - (_, seqlen_vc, nheads_vc, dim_vc), ( - stride_vc_z, - stride_vc_h, - stride_vc_n, - stride_vc_d, - ) = get_shape_and_strides_from_layout(v_cache, layout) + _, seqlen_kc, nheads_kc, dim_kc = get_shape_from_layout(k_cache, layout) + stride_kc_z, stride_kc_h, stride_kc_n, stride_kc_d = get_stride_from_layout(k_cache, layout) + _, seqlen_vc, nheads_vc, dim_vc = get_shape_from_layout(v_cache, layout) + stride_vc_z, stride_vc_h, stride_vc_n, stride_vc_d = get_stride_from_layout(v_cache, layout) block_size_k = 0 # Not used if is_new_kv: - (_, seqlen_kn, nheads_kn, dim_kn), ( - stride_kn_z, - stride_kn_h, - stride_kn_n, - stride_kn_d, - ) = get_shape_and_strides_from_layout(k_new, layout) - (_, seqlen_vn, nheads_vn, dim_vn), ( - stride_vn_z, - stride_vn_h, - stride_vn_n, - stride_vn_d, - ) = get_shape_and_strides_from_layout(v_new, layout) + _, seqlen_kn, nheads_kn, dim_kn = get_shape_from_layout(k_new, layout) + stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d = get_stride_from_layout(k_new, layout) + _, seqlen_vn, nheads_vn, dim_vn = get_shape_from_layout(v_new, layout) + stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d = get_stride_from_layout(v_new, layout) else: - (_, seqlen_kn, nheads_kn, dim_kn), ( - stride_kn_z, - stride_kn_h, - stride_kn_n, - stride_kn_d, - ) = (None, None, None, None,), (None, None, None, None) - (_, seqlen_vn, nheads_vn, dim_vn), ( - stride_vn_z, - stride_vn_h, - stride_vn_n, - stride_vn_d, - ) = (None, None, None, None,), (None, None, None, None) - (_, seqlen_o, nheads_o, dim_o), (stride_oz, stride_oh, stride_om, stride_od) = ( - get_shape_and_strides_from_layout(out, layout) - ) + _, seqlen_kn, nheads_kn, dim_kn = None, None, None, None + stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d = None, None, None, None + _, seqlen_vn, nheads_vn, dim_vn = None, None, None, None + stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d = None, None, None, None + _, seqlen_o, nheads_o, dim_o = get_shape_from_layout(out, layout) + stride_oz, stride_oh, stride_om, stride_od = get_stride_from_layout(out, layout) assert ( dim_q == dim_kc == dim_vc ), f"Dimensions must match: {dim_q}, {dim_kc}, {dim_vc}" # add extra information needed by the kernels - if layout == "bshd": - (n_group_q, heads_per_group_q), stride_qg = (1, nheads_q), stride_qm - (n_group_k, heads_per_group_k), stride_kc_g = (1, nheads_kc), stride_kc_n - (n_group_v, heads_per_group_v), stride_vc_g = (1, nheads_vc), stride_vc_n - if is_new_kv: - (n_group_kn, heads_per_group_kn), stride_kn_g = (1, nheads_kn), stride_kn_n - (n_group_vn, heads_per_group_vn), stride_vn_g = (1, nheads_vn), stride_vn_n - else: - (n_group_kn, heads_per_group_kn), stride_kn_g = (None, None), None - (n_group_vn, heads_per_group_vn), stride_vn_g = (None, None), None - (n_group_o, heads_per_group_o), stride_og = (1, nheads_o), stride_om + (n_group_q, heads_per_group_q), stride_qg = (1, nheads_q), stride_qm + (n_group_k, heads_per_group_k), stride_kc_g = (1, nheads_kc), stride_kc_n + (n_group_v, heads_per_group_v), stride_vc_g = (1, nheads_vc), stride_vc_n + if is_new_kv: + (n_group_kn, heads_per_group_kn), stride_kn_g = (1, nheads_kn), stride_kn_n + (n_group_vn, heads_per_group_vn), stride_vn_g = (1, nheads_vn), stride_vn_n else: - raise ValueError(f"{layout} layout is not supported") + (n_group_kn, heads_per_group_kn), stride_kn_g = (None, None), None + (n_group_vn, heads_per_group_vn), stride_vn_g = (None, None), None + (n_group_o, heads_per_group_o), stride_og = (1, nheads_o), stride_om # get padded size dim_padded = get_padded_headsize(dim_kc) @@ -1076,20 +1048,17 @@ def attention_forward_decode_triton_impl( else: is_gqa = False - if SPLIT_K is not None: - split_k = SPLIT_K + # Use heuristics for split_k + if use_block_table: + # For paged attention, use the actual sequence length from cache_seqlens + max_seqlen = ( + int(cache_seqlens.max().item()) + if cache_seqlens is not None + else block_size_k + ) + split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, max_seqlen) else: - # Use heuristics - if use_block_table: - # For paged attention, use the actual sequence length from cache_seqlens - max_seqlen = ( - int(cache_seqlens.max().item()) - if cache_seqlens is not None - else block_size_k - ) - split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, max_seqlen) - else: - split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_kc) + split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_kc) split_size = (seqlen_kc + split_k - 1) // split_k # setup grid - use lambda to get BLOCK_M from autotune @@ -1141,6 +1110,7 @@ def attention_forward_decode_triton_impl( # Block table strides if use_block_table: + assert block_table is not None stride_bt_b, stride_bt_s = block_table.stride() else: stride_bt_b, stride_bt_s = 0, 0 @@ -1148,13 +1118,17 @@ def attention_forward_decode_triton_impl( # FP8 support IS_FP8 = is_fp8([q, k_cache, v_cache]) if IS_FP8: - rec_dtype = get_recommended_fp8_dtype(q) + arch = get_arch() + if not arch.supports_fp8: + raise RuntimeError( + f"{arch.name} does not support FP8" + ) + rec_dtype = arch.recommended_fp8_dtype(q.dtype) if ( q.dtype != rec_dtype or k_cache.dtype != rec_dtype or v_cache.dtype != rec_dtype ): - arch = get_arch() warnings.warn( f"Use {rec_dtype} data type on {arch}. Got q: {q.dtype}, k: {k_cache.dtype}, v: {v_cache.dtype}", UserWarning, @@ -1360,15 +1334,15 @@ def attention_forward_decode_triton_impl( k_block_num = 2 assert dim_padded % k_block_num == 0 k_block_size = dim_padded // k_block_num - grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num) + reduce_grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num) if DEBUG: print("splitK_pow2:", splitK_pow2) print("k_block_num:", k_block_num) print("k_block_size:", k_block_size) - print("grid:", grid) + print("grid:", reduce_grid) - _splitK_reduce[grid]( + _splitK_reduce[reduce_grid]( out_splitk, metadata, out, diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index ff2b3d71131..ef8a9d5ff45 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -4,16 +4,12 @@ import triton import triton.language as tl from typing import Literal, Optional +from .common import compute_alibi_block, compute_fp8_scaling_factors, apply_rotary from .utils import ( DEBUG, AUTOTUNE, - compute_alibi_block, - compute_fp8_scaling_factors, get_arch, - get_cu_count, is_fp8, - apply_rotary, - get_recommended_fp8_dtype, ) @@ -53,7 +49,7 @@ def get_fwd_prefill_configs(autotune: bool): ) ] elif arch.name == "gfx942": - if get_cu_count() < 304: + if arch.cu_count < 304: return [ triton.Config( { @@ -1534,10 +1530,14 @@ def attention_forward_prefill_triton_impl( # fp8 setup and assertions IS_FP8 = is_fp8([q, k, v]) if IS_FP8: + arch = get_arch() + if not arch.supports_fp8: + raise RuntimeError( + f"{arch.name} does not support FP8" + ) FP8_MAX = torch.finfo(q.dtype).max - rec_dtype = get_recommended_fp8_dtype(q) + rec_dtype = arch.recommended_fp8_dtype(q.dtype) if q.dtype != rec_dtype or k.dtype != rec_dtype or v.dtype != rec_dtype: - arch = get_arch() warnings.warn( f"Use {rec_dtype} data type on {arch}. Got q: {q.dtype}, k: {k.dtype}, v: {v.dtype}", UserWarning, diff --git a/flash_attn/flash_attn_triton_amd/interface_v2.py b/flash_attn/flash_attn_triton_amd/interface_v2.py index d4803a2c252..e0669779be4 100644 --- a/flash_attn/flash_attn_triton_amd/interface_v2.py +++ b/flash_attn/flash_attn_triton_amd/interface_v2.py @@ -1,6 +1,6 @@ import torch import os -from typing import Optional, Union +from typing import Literal, Optional, Union from .fwd_prefill import attention_forward_prefill_triton_impl from .fwd_decode import attention_forward_decode_triton_impl from .bwd import attention_backward_triton_impl @@ -12,7 +12,6 @@ PHILOX_OFFSET, SHAPE_EXPECTATIONS, round_multiple, - tensor_stats, ) @@ -30,7 +29,7 @@ def fwd( softcap: float, return_softmax: bool, gen_: Optional[torch.Tensor] = None, -): +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]: # Reject FP8 tensors (FA2 AMD path does not support FP8) if str(q.dtype).startswith("torch.float8"): @@ -47,11 +46,11 @@ def fwd( if DEBUG: print() print("flash_attn_triton_amd.py::fwd inputs") - print(tensor_stats("q", q)) - print(tensor_stats("k", k)) - print(tensor_stats("v", v)) - print(tensor_stats("out", out)) - print(tensor_stats("alibi_slopes", alibi_slopes)) + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("out:", out.shape if out is not None else None) + print("alibi_slopes:", alibi_slopes.shape if alibi_slopes is not None else None) print("dropout_p:", dropout_p) print("softmax_scale:", softmax_scale) print("causal:", causal) @@ -66,7 +65,7 @@ def fwd( out.zero_() # Layout / shapes - layout = "bshd" + layout: Literal["bshd", "bhsd", "thd"] = "bshd" max_seqlen_q = q.shape[1] max_seqlen_k = k.shape[1] batch, _, nheads_q, _ = q.shape @@ -162,9 +161,9 @@ def fwd( if DEBUG: print("flash_attn_triton_amd.py::fwd outputs") - print(tensor_stats("out", out)) - print(tensor_stats("softmax_lse", softmax_lse)) - print(tensor_stats("sd_mask", sd_mask)) + print("out:", out.shape) + print("softmax_lse:", softmax_lse.shape) + print("sd_mask:", sd_mask.shape if sd_mask is not None else None) print("rng_state:", rng_state) # --- Assertions (shape + dtype contracts) --- @@ -227,7 +226,7 @@ def bwd( deterministic: bool, gen_: Optional[torch.Tensor] = None, rng_state: Optional[torch.Tensor] = None, -): +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: if softcap != 0.0: raise NotImplementedError( "softcap is not supported in the AMD Triton FA2 interface (expected 0.0)." @@ -245,16 +244,16 @@ def bwd( if DEBUG: print() print("flash_attn_triton_amd.py::bwd inputs") - print(tensor_stats("dout", dout)) - print(tensor_stats("q", q)) - print(tensor_stats("k", k)) - print(tensor_stats("v", v)) - print(tensor_stats("out", out)) - print(tensor_stats("softmax_lse", softmax_lse)) - print(tensor_stats("dq", dq)) - print(tensor_stats("dk", dk)) - print(tensor_stats("dv", dv)) - print(tensor_stats("alibi_slopes", alibi_slopes)) + print("dout:", dout.shape) + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("out:", out.shape) + print("softmax_lse:", softmax_lse.shape) + print("dq:", dq.shape if dq is not None else None) + print("dk:", dk.shape if dk is not None else None) + print("dv:", dv.shape if dv is not None else None) + print("alibi_slopes:", alibi_slopes.shape if alibi_slopes is not None else None) print("dropout_p:", dropout_p) print("softmax_scale:", softmax_scale) print("causal:", causal) @@ -330,9 +329,9 @@ def bwd( if DEBUG: print("flash_attn_triton_amd.py::bwd outputs") - print(tensor_stats("dq", dq)) - print(tensor_stats("dk", dk)) - print(tensor_stats("dv", dv)) + print("dq:", dq.shape) + print("dk:", dk.shape) + print("dv:", dv.shape) # --- Assertions --- assert dq.shape == q.shape, f"[bwd] dq shape {dq.shape} != q shape {q.shape}" assert dk.shape == k.shape, f"[bwd] dk shape {dk.shape} != k shape {k.shape}" @@ -370,7 +369,7 @@ def varlen_fwd( softcap: float, return_softmax: bool, gen_: Optional[torch.Tensor] = None, -): +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]: if str(q.dtype).startswith("torch.float8"): raise NotImplementedError( @@ -414,7 +413,7 @@ def varlen_fwd( out = torch.zeros_like(q) if out is None else out.zero_() # Layout and basic info for varlen - layout = "thd" + layout: Literal["bshd", "bhsd", "thd"] = "thd" batch = len(cu_seqlens_q) - 1 total_q, nheads_q, _ = q.shape @@ -565,7 +564,7 @@ def varlen_bwd( deterministic: bool, gen_: Optional[torch.Tensor] = None, rng_state: Optional[torch.Tensor] = None, -): +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: if str(q.dtype).startswith("torch.float8"): raise NotImplementedError( "FP8 tensors are not supported in the AMD Triton FA2 interface (varlen_bwd). Use the FA3 path instead." @@ -690,7 +689,7 @@ def fwd_kvcache( v_cache: torch.Tensor, k: Optional[torch.Tensor], v: Optional[torch.Tensor], - cache_seqlens: Optional[Union[(int, torch.Tensor)]], + cache_seqlens: Optional[Union[int, torch.Tensor]], rotary_cos: Optional[torch.Tensor], rotary_sin: Optional[torch.Tensor], cache_batch_idx: Optional[torch.Tensor], @@ -705,7 +704,7 @@ def fwd_kvcache( softcap: float, rotary_interleaved: bool, num_splits: int, -): +) -> tuple[torch.Tensor, torch.Tensor]: if softcap != 0.0: raise NotImplementedError( @@ -744,7 +743,7 @@ def fwd_kvcache( out = torch.zeros_like(q) if out is None else out.zero_() # Basic layout info for decode path - layout = "bshd" + layout: Literal["bshd"] = "bshd" max_seqlen_q = q.shape[1] max_seqlen_k = k_cache.shape[1] cache_seqlens_tensor = ( diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py index 077bb58319c..c38c190ac35 100755 --- a/flash_attn/flash_attn_triton_amd/interface_v3.py +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -1,7 +1,7 @@ import os import warnings import torch -from typing import Optional, Union, Tuple +from typing import Literal, Optional, Union, Tuple from .fwd_prefill import attention_forward_prefill_triton_impl from .fwd_decode import attention_forward_decode_triton_impl from .bwd import attention_backward_triton_impl @@ -12,8 +12,6 @@ PHILOX_SEED, PHILOX_OFFSET, is_fp8, - get_recommended_fp8_dtype, - tensor_stats, ) @@ -48,9 +46,9 @@ def fwd( attention_chunk: int, softcap: float, rotary_interleaved: bool, - scheduler_metadata=None, + scheduler_metadata: None = None, num_splits: int = 1, - pack_gqa=None, + pack_gqa: Optional[bool] = None, sm_margin: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ @@ -62,29 +60,29 @@ def fwd( if DEBUG: print() print("interface_fa_v3.py::fwd inputs") - print(tensor_stats("q", q)) - print(tensor_stats("k", k)) - print(tensor_stats("v", v)) - print(tensor_stats("k_new", k_new)) - print(tensor_stats("v_new", v_new)) - print(tensor_stats("qv", qv)) - print(tensor_stats("out", out)) - print(tensor_stats("cu_seqlens_q", cu_seqlens_q)) - print(tensor_stats("cu_seqlens_k", cu_seqlens_k)) - print(tensor_stats("cu_seqlens_k_new", cu_seqlens_k_new)) - print(tensor_stats("seqused_q", seqused_q)) - print(tensor_stats("seqused_k", seqused_k)) + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("k_new:", k_new.shape if k_new is not None else None) + print("v_new:", v_new.shape if v_new is not None else None) + print("qv:", qv.shape if qv is not None else None) + print("out:", out.shape if out is not None else None) + print("cu_seqlens_q:", cu_seqlens_q.shape if cu_seqlens_q is not None else None) + print("cu_seqlens_k:", cu_seqlens_k.shape if cu_seqlens_k is not None else None) + print("cu_seqlens_k_new:", cu_seqlens_k_new.shape if cu_seqlens_k_new is not None else None) + print("seqused_q:", seqused_q.shape if seqused_q is not None else None) + print("seqused_k:", seqused_k.shape if seqused_k is not None else None) print("max_seqlen_q:", max_seqlen_q) print("max_seqlen_k:", max_seqlen_k) - print(tensor_stats("page_table", page_table)) - print(tensor_stats("kv_batch_idx", kv_batch_idx)) - print(tensor_stats("leftpad_k", leftpad_k)) - print(tensor_stats("rotary_cos", rotary_cos)) - print(tensor_stats("rotary_sin", rotary_sin)) - print(tensor_stats("seqlens_rotary", seqlens_rotary)) - print(tensor_stats("q_descale", q_descale)) - print(tensor_stats("k_descale", k_descale)) - print(tensor_stats("v_descale", v_descale)) + print("page_table:", page_table.shape if page_table is not None else None) + print("kv_batch_idx:", kv_batch_idx.shape if kv_batch_idx is not None else None) + print("leftpad_k:", leftpad_k.shape if leftpad_k is not None else None) + print("rotary_cos:", rotary_cos.shape if rotary_cos is not None else None) + print("rotary_sin:", rotary_sin.shape if rotary_sin is not None else None) + print("seqlens_rotary:", seqlens_rotary.shape if seqlens_rotary is not None else None) + print("q_descale:", q_descale.shape if q_descale is not None else None) + print("k_descale:", k_descale.shape if k_descale is not None else None) + print("v_descale:", v_descale.shape if v_descale is not None else None) print("softmax_scale:", softmax_scale) print("causal:", causal) print("window_size_left:", window_size_left) @@ -157,15 +155,21 @@ def fwd( raise ValueError( f"cu_seqlens_q provided but q has shape {q.shape}, expected 3D tensor for varlen" ) - layout = "thd" + layout: Literal["bshd", "thd"] = "thd" cu_seqlens_q_local = cu_seqlens_q + assert max_seqlen_q is not None, "max_seqlen_q required for varlen mode" max_seqlens_q_local = max_seqlen_q if cu_seqlens_k is not None: cu_seqlens_k_local = cu_seqlens_k + assert max_seqlen_k is not None, "max_seqlen_k required when cu_seqlens_k provided" max_seqlens_k_local = max_seqlen_k else: cu_seqlens_k_local = None - max_seqlens_k_local = k.shape[1] if len(k.shape) == 4 else max_seqlen_k + if len(k.shape) == 4: + max_seqlens_k_local = k.shape[1] + else: + assert max_seqlen_k is not None, "max_seqlen_k required for varlen mode" + max_seqlens_k_local = max_seqlen_k else: layout = "bshd" cu_seqlens_q_local = None @@ -254,6 +258,8 @@ def fwd( (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 ) + # Decode only supports bshd layout + assert layout == "bshd", f"decode requires bshd layout, got {layout}" attention_forward_decode_triton_impl( q, k, @@ -336,8 +342,8 @@ def fwd( if DEBUG: print("interface_fa_v3.py::fwd outputs") - print(tensor_stats("out", out)) - print(tensor_stats("softmax_lse", softmax_lse)) + print("out:", out.shape) + print("softmax_lse:", softmax_lse.shape) # --- Assertions (FA3 always expects exact shapes) --- # out: same shape as q except last dim is v's head_dim @@ -372,6 +378,7 @@ def fwd( softmax_lse.dtype == torch.float32 ), f"[fwd_v3] softmax_lse dtype {softmax_lse.dtype} != torch.float32" # softmax_lse shape depends on layout + expected_lse_shape: tuple[int, ...] if layout == "thd": # varlen: (Hq, Total_Q) expected_lse_shape = (q.shape[1], q.shape[0]) @@ -421,19 +428,19 @@ def bwd( if DEBUG: print() print("interface_fa_v3.py::bwd inputs") - print(tensor_stats("dout", dout)) - print(tensor_stats("q", q)) - print(tensor_stats("k", k)) - print(tensor_stats("v", v)) - print(tensor_stats("out", out)) - print(tensor_stats("softmax_lse", softmax_lse)) - print(tensor_stats("dq", dq)) - print(tensor_stats("dk", dk)) - print(tensor_stats("dv", dv)) - print(tensor_stats("cu_seqlens_q", cu_seqlens_q)) - print(tensor_stats("cu_seqlens_k", cu_seqlens_k)) - print(tensor_stats("seqused_q", seqused_q)) - print(tensor_stats("seqused_k", seqused_k)) + print("dout:", dout.shape) + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("out:", out.shape) + print("softmax_lse:", softmax_lse.shape) + print("dq:", dq.shape if dq is not None else None) + print("dk:", dk.shape if dk is not None else None) + print("dv:", dv.shape if dv is not None else None) + print("cu_seqlens_q:", cu_seqlens_q.shape if cu_seqlens_q is not None else None) + print("cu_seqlens_k:", cu_seqlens_k.shape if cu_seqlens_k is not None else None) + print("seqused_q:", seqused_q.shape if seqused_q is not None else None) + print("seqused_k:", seqused_k.shape if seqused_k is not None else None) print("max_seqlen_q:", max_seqlen_q) print("max_seqlen_k:", max_seqlen_k) print("softmax_scale:", softmax_scale) @@ -475,6 +482,7 @@ def bwd( dv = torch.zeros_like(v, dtype=grad_dtype) if dv is None else dv.zero_() # Determine layout based on cu_seqlens + layout: Literal["bshd", "bhsd", "thd"] if cu_seqlens_q is not None and cu_seqlens_k is not None: # Variable length sequence mode layout = "thd" @@ -531,10 +539,10 @@ def bwd( if DEBUG: print("interface_fa_v3.py::bwd outputs") - print(tensor_stats("dq", dq)) - print(tensor_stats("dk", dk)) - print(tensor_stats("dv", dv)) - print(tensor_stats("delta", delta)) + print("dq:", dq.shape) + print("dk:", dk.shape) + print("dv:", dv.shape) + print("delta:", delta.shape) # --- Assertions (FA3 always expects exact shapes) --- # Gradients should match input shapes @@ -545,6 +553,7 @@ def bwd( assert ( delta.dtype == torch.float32 ), f"[bwd_v3] delta dtype {delta.dtype} != torch.float32" + expected_delta_shape: tuple[int, ...] if layout == "thd": # varlen: (Hq, Total_Q) expected_delta_shape = (q.shape[1], q.shape[0]) @@ -557,7 +566,7 @@ def bwd( # V3 expects (softmax_d, *rest) # delta is the softmax_d in this case - return delta + return (delta,) def fwd_combine( @@ -565,7 +574,7 @@ def fwd_combine( lse_partial: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = None, -) -> torch.Tensor: +) -> "torch.Tensor": """ Combine partial outputs from split attention computation. @@ -610,7 +619,7 @@ def get_scheduler_metadata( num_splits: int = 0, pack_gqa: Optional[bool] = None, sm_margin: int = 0, -): +) -> None: """ Get scheduler metadata for optimized kernel selection. diff --git a/flash_attn/flash_attn_triton_amd/pyproject.toml b/flash_attn/flash_attn_triton_amd/pyproject.toml new file mode 100644 index 00000000000..3a07ef28ed9 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/pyproject.toml @@ -0,0 +1,48 @@ +# mypy --config-file flash_attn/flash_attn_triton_amd/pyproject.toml +[tool.mypy] +files = [ + # Core Triton AMD backend + "flash_attn/flash_attn_triton_amd", + # Tests (based on test_flash_attn.py - looser rules, but catches import errors) + "tests/test_flash_attn_triton_amd.py", + "hopper/test_flash_attn_triton_amd.py", +] +ignore_missing_imports = true +follow_imports = "skip" +python_version = "3.9" + +# Strict checks +strict_equality = true +warn_unreachable = true +warn_redundant_casts = true +warn_unused_ignores = true +check_untyped_defs = true +warn_return_any = true +warn_unused_configs = true +no_implicit_optional = true +strict_optional = true +disallow_incomplete_defs = false # Triton kernels can't be fully typed +disallow_subclassing_any = false # torch.autograd.Function has type Any + +# Triton kernels use untyped decorators and defs +disallow_untyped_defs = false +disallow_untyped_decorators = false +disallow_untyped_calls = false + +# Follow imports for our module so test imports are validated +[[tool.mypy.overrides]] +module = ["flash_attn.flash_attn_triton_amd", "flash_attn.flash_attn_triton_amd.*"] +follow_imports = "normal" + +# Stricter settings for interface and utility modules only +[[tool.mypy.overrides]] +module = ["flash_attn.flash_attn_triton_amd.interface_v2", "flash_attn.flash_attn_triton_amd.interface_v3", "flash_attn.flash_attn_triton_amd.utils"] +disallow_incomplete_defs = true +disallow_untyped_defs = true + +# Test files - based on test_flash_attn.py, looser rules but catches import/export errors +[[tool.mypy.overrides]] +module = ["test_flash_attn_triton_amd", "hopper.test_flash_attn_triton_amd"] +strict_optional = false +check_untyped_defs = false + diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 7743fcf4e47..358467157c7 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -1,14 +1,43 @@ -import csv -import math -import torch -import os -import random +""" +Utilities for Flash Attention Triton AMD backend. + +This module contains essential runtime utilities: +- GPU architecture detection +- Global configuration flags +- Tensor shape/stride helpers +- FP8 type detection +""" import functools -import triton -import triton.language as tl -import numpy as np +import os from dataclasses import dataclass -from typing import Literal, Optional, Union, Tuple +from typing import Literal, Optional, Union + +import torch +import triton + + +__all__ = [ + # Runtime info + "get_arch", + "is_hip", + # Global config + "AUTOTUNE", + "DEBUG", + "USE_TRITON_ROCM", + "BWD_MODE", + "USE_EXP2", + "PHILOX_SEED", + "PHILOX_OFFSET", + "SHAPE_EXPECTATIONS", + # FP8 + "is_fp8", + # Shape/stride helpers + "get_shape_from_layout", + "get_stride_from_layout", + "get_padded_headsize", + # Misc helpers + "round_multiple", +] # ------------------------------- @@ -18,6 +47,14 @@ CDNA_ARCHS = frozenset({"gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx950"}) RDNA_ARCHS = frozenset({"gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"}) +FP8_ARCHS = frozenset({"gfx942", "gfx950"}) + +_RECOMMENDED_FP8_REPLACEMENTS: dict[str, dict[torch.dtype, torch.dtype]] = { + "gfx942": { + torch.float8_e4m3fn: torch.float8_e4m3fnuz, + torch.float8_e5m2: torch.float8_e5m2fnuz, + }, +} @dataclass(frozen=True) @@ -34,35 +71,50 @@ def is_cdna(self) -> bool: def is_rdna(self) -> bool: return self.family == "rdna" + @property + def supports_fp8(self) -> bool: + """Check if this architecture supports FP8.""" + return self.name in FP8_ARCHS + + def recommended_fp8_dtype(self, dtype: torch.dtype) -> torch.dtype: + """Get the recommended FP8 dtype for this architecture. + + Some architectures prefer different FP8 variants (e.g., fnuz vs fn). + Returns the input dtype unchanged if no replacement is recommended. + """ + return _RECOMMENDED_FP8_REPLACEMENTS.get(self.name, {}).get(dtype, dtype) + + @property + def cu_count(self) -> int: + """Get the number of compute units on the current GPU.""" + return int( + torch.cuda.get_device_properties( + torch.cuda.current_device() + ).multi_processor_count + ) + + # ------------------------------- -# Gloabl Variables +# Global Variables # ------------------------------- +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" AUTOTUNE = os.environ.get("FLASH_ATTENTION_TRITON_AMD_AUTOTUNE", "0").lower() in ( "1", "true", "yes", ) -DEBUG = os.environ.get("FLASH_ATTENTION_TRITON_AMD_DEBUG", "0").lower() in ( - "1", - "true", - "yes", -) -if AUTOTUNE or DEBUG: + +# Unified debug level: +# 0 = off (default) +# 1 = basic debug info (shapes, tensor stats, kernel params) +# 2 = detailed debug (includes Triton interpreter prints in kernels) +# +# Set via: FLASH_ATTENTION_TRITON_AMD_DEBUG=0|1|2 +DEBUG: int = int(os.environ.get("FLASH_ATTENTION_TRITON_AMD_DEBUG", "0")) +if AUTOTUNE or DEBUG > 0: os.environ["TRITON_PRINT_AUTOTUNING"] = "1" -USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" -USE_TRITON_INTERPRET = os.environ.get("TRITON_INTERPRET", "0").lower() in ( - "1", - "true", - "yes", -) -DEBUG_TRITON = ( - os.environ.get("DEBUG_TRITON", "0").lower() in ("1", "true", "yes") - and USE_TRITON_INTERPRET -) -DEBUG_TRITON_DETAIL = ( - os.environ.get("DEBUG_TRITON_DETAIL", "0").lower() in ("1", "true", "yes") - and USE_TRITON_INTERPRET -) +if DEBUG >= 2: + os.environ["TRITON_INTERPRET"] = "1" BWD_MODE: Literal["fused", "fused_atomic", "split"] = "fused" USE_EXP2 = True PHILOX_SEED = 0x1BF58 @@ -70,614 +122,51 @@ def is_rdna(self) -> bool: SHAPE_EXPECTATIONS: Literal["exact", "rounded"] = "exact" -def tensor_stats(name: str, t: torch.Tensor) -> str: - """Return a string with tensor shape, dtype, and distribution stats for debugging.""" - if t is None: - return f"{name}: None" - flat = t.float().flatten() - return ( - f"{name}: shape={tuple(t.shape)}, dtype={t.dtype}, " - f"min={flat.min().item():.6g}, max={flat.max().item():.6g}, " - f"mean={flat.mean().item():.6g}, median={flat.median().item():.6g}, " - f"std={flat.std().item():.6g}" - ) - - -# ------------------------------- -# Input Helper -# ------------------------------- -def random_seqlens_composition(SEQ_LEN, BATCH): - # generate a random composition of N into Z positive parts. - idx = torch.randperm(SEQ_LEN - 1)[: BATCH - 1] + 1 - idx, _ = torch.sort(idx) - breakpoints = torch.cat( - [ - torch.tensor([0], dtype=torch.long), - idx, - torch.tensor([SEQ_LEN], dtype=torch.long), - ] - ) - seqlens = (breakpoints[1:] - breakpoints[:-1]).to(torch.int32) - return seqlens - - -def generate_varlen_tensor( - total_seqlen: int, - num_heads: int, - head_size: int, - batch_size: Optional[int] = None, - equal_seqlens: bool = False, - device: str = "cuda", - dtype: torch.dtype = torch.float16, - mode: Literal["random", "ones", "incremental", "identity"] = "random", -): - if DEBUG: - print("total_seqlen", total_seqlen) - print("num_heads", num_heads) - print("head_size", head_size) - - # save fp8 type - is_fp8_dtype = is_dtype_fp8(dtype) - if is_fp8_dtype: - og_fp8_dtype = dtype - dtype = torch.float32 - - # get valid batch_size - if batch_size is None: - valid_batch_sizes = [ - bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen - ] - batch_size = random.choice(valid_batch_sizes) - - # get seqlens - if equal_seqlens: - seqlens = torch.full( - (batch_size,), total_seqlen // batch_size, dtype=torch.int32, device=device - ) - seqlens[-1] += total_seqlen % batch_size - else: - seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device) - - # create cumulative sequence lengths - cu_seqlens = ( - torch.cat( - [torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)] - ) - .to(torch.int32) - .to(device=device) - ) - max_seqlen = torch.max(seqlens).to(torch.int32).item() - - # create varlen tensor based on mode - if mode == "incremental": - x = torch.zeros(total_seqlen, num_heads, head_size, dtype=dtype, device=device) - for i in range(batch_size): - start = cu_seqlens[i].item() - end = cu_seqlens[i + 1].item() - length = end - start - - x[start:end, :, :] = ( - torch.arange(length, dtype=dtype, device=device) - .view(length, 1, 1) - .expand(length, num_heads, head_size) - ) - elif mode == "identity": - x = torch.zeros(total_seqlen, num_heads, head_size, dtype=dtype, device=device) - # for each batch, create identity pattern within that batch's sequence - for i in range(batch_size): - start = cu_seqlens[i].item() - end = cu_seqlens[i + 1].item() - length = end - start - - # create identity pattern for positions within this batch - for pos in range(min(length, head_size)): - x[start + pos, :, pos] = 1.0 - elif mode == "random": - x = torch.randn( - (total_seqlen, num_heads, head_size), dtype=dtype, device=device - ) - elif mode == "ones": - x = torch.ones((total_seqlen, num_heads, head_size), dtype=dtype, device=device) - else: - raise ValueError(f"Unkown mode {mode}") - - if is_fp8_dtype: - # cast to fp8 - x, descale_x = cast_to_fp8( - x, og_fp8_dtype, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ) - x.requires_grad_() - return x, cu_seqlens, max_seqlen, descale_x - else: - x.requires_grad_() - return x, cu_seqlens, max_seqlen - - -def generate_bshd_tensor( - BATCH, - SEQ_LEN, - NUM_HEADS, - D_HEAD, - dtype: torch.dtype = torch.float16, - device="cuda", - mode: Literal["random", "ones", "incremental", "identity"] = "random", -): - # save fp8 type - is_fp8_dtype = is_dtype_fp8(dtype) - if is_fp8_dtype: - og_fp8_dtype = dtype - dtype = torch.float32 - - # gen tensor based on mode - tensor_shape = (BATCH, SEQ_LEN, NUM_HEADS, D_HEAD) - if mode == "incremental": - x = ( - torch.arange(SEQ_LEN, dtype=dtype, device=device) - .view(1, SEQ_LEN, 1, 1) - .expand(*tensor_shape) - .contiguous() - ) - elif mode == "identity": - x = torch.zeros(tensor_shape, dtype=dtype, device=device) - # create identity pattern: position i has value 1 at dimension i - for i in range(min(SEQ_LEN, D_HEAD)): - x[:, i, :, i] = 1.0 - elif mode == "random": - x = torch.randn(tensor_shape, dtype=dtype, device=device) - elif mode == "ones": - x = torch.ones(tensor_shape, dtype=dtype, device=device) - else: - raise ValueError(f"Unkown mode {mode}") - - if is_fp8_dtype: - # cast to fp8 - x, descale_x = cast_to_fp8(x, og_fp8_dtype, "bshd") - x.requires_grad_() - return x, descale_x - else: - x.requires_grad_() - return x - - -def generate_bhsd_tensor( - BATCH, - NUM_HEADS, - SEQ_LEN, - D_HEAD, - dtype: torch.dtype = torch.float16, - device="cuda", - mode: Literal["random", "ones", "incremental", "identity"] = "random", -): - # save fp8 type - is_fp8_dtype = is_dtype_fp8(dtype) - if is_fp8_dtype: - og_fp8_dtype = dtype - dtype = torch.float32 - - # gen tensor based on mode - tensor_shape = (BATCH, NUM_HEADS, SEQ_LEN, D_HEAD) - if mode == "incremental": - x = ( - torch.arange(SEQ_LEN, dtype=dtype, device=device) - .view(1, 1, SEQ_LEN, 1) - .expand(*tensor_shape) - .contiguous() - ) - elif mode == "identity": - x = torch.zeros(tensor_shape, dtype=dtype, device=device) - # create identity pattern: position i has value 1 at dimension i - for i in range(min(SEQ_LEN, D_HEAD)): - x[:, :, i, i] = 1.0 - elif mode == "random": - x = torch.randn(tensor_shape, dtype=dtype, device=device) - elif mode == "ones": - x = torch.ones(tensor_shape, dtype=dtype, device=device) - else: - raise ValueError(f"Unkown mode {mode}") - - if is_fp8_dtype: - raise ValueError("fp8 not supported for bhsd yet") - else: - x.requires_grad_() - return x - - -def generate_bshd_qkv_packed( - BATCH, - SEQ_LEN, - NUM_HEADS, - D_HEAD, - dtype: torch.dtype = torch.float16, - device="cuda", - DEBUG_INPUT=False, -): - """Generate QKV packed tensor with shape (BATCH, SEQ_LEN, 3, NUM_HEADS, D_HEAD)""" - # save fp8 type - is_fp8_dtype = is_dtype_fp8(dtype) - if is_fp8_dtype: - og_fp8_dtype = dtype - dtype = torch.float32 - - # gen tensor - tensor_shape = (BATCH, SEQ_LEN, 3, NUM_HEADS, D_HEAD) - if DEBUG_INPUT: - x = ( - torch.arange(SEQ_LEN, dtype=dtype, device=device) - .view(1, SEQ_LEN, 1, 1, 1) - .expand(*tensor_shape) - .contiguous() - ) - else: - x = torch.randn(tensor_shape, dtype=dtype, device=device) - - if is_fp8_dtype: - # cast to fp8 - need to handle the packed dimension - raise NotImplementedError("FP8 not supported for QKV packing yet") - else: - x.requires_grad_() - return x - - -def generate_bshd_kv_packed( - BATCH, - SEQ_LEN, - NUM_HEADS, - D_HEAD, - dtype: torch.dtype = torch.float16, - device="cuda", - DEBUG_INPUT=False, -): - """Generate KV packed tensor with shape (BATCH, SEQ_LEN, 2, NUM_HEADS, D_HEAD)""" - # save fp8 type - is_fp8_dtype = is_dtype_fp8(dtype) - if is_fp8_dtype: - og_fp8_dtype = dtype - dtype = torch.float32 - - # gen tensor - tensor_shape = (BATCH, SEQ_LEN, 2, NUM_HEADS, D_HEAD) - if DEBUG_INPUT: - x = ( - torch.arange(SEQ_LEN, dtype=dtype, device=device) - .view(1, SEQ_LEN, 1, 1, 1) - .expand(*tensor_shape) - .contiguous() - ) - else: - x = torch.randn(tensor_shape, dtype=dtype, device=device) - - if is_fp8_dtype: - # cast to fp8 - need to handle the packed dimension - raise NotImplementedError("FP8 not supported for KV packing yet") - else: - x.requires_grad_() - return x - - -def generate_bhsd_qkv_packed( - BATCH, - NUM_HEADS, - SEQ_LEN, - D_HEAD, - dtype: torch.dtype = torch.float16, - device="cuda", - DEBUG_INPUT=False, -): - """Generate QKV packed tensor with shape (BATCH, 3, NUM_HEADS, SEQ_LEN, D_HEAD)""" - # save fp8 type - is_fp8_dtype = is_dtype_fp8(dtype) - if is_fp8_dtype: - og_fp8_dtype = dtype - dtype = torch.float32 - - # gen tensor - tensor_shape = (BATCH, 3, NUM_HEADS, SEQ_LEN, D_HEAD) - if DEBUG_INPUT: - x = ( - torch.arange(SEQ_LEN, dtype=dtype, device=device) - .view(1, 1, 1, SEQ_LEN, 1) - .expand(*tensor_shape) - .contiguous() - ) - else: - x = torch.randn(tensor_shape, dtype=dtype, device=device) - - if is_fp8_dtype: - # cast to fp8 - need to handle the packed dimension - raise NotImplementedError("FP8 not supported for QKV packing yet") - else: - x.requires_grad_() - return x - - -def generate_bhsd_kv_packed( - BATCH, - NUM_HEADS, - SEQ_LEN, - D_HEAD, - dtype: torch.dtype = torch.float16, - device="cuda", - DEBUG_INPUT=False, -): - """Generate KV packed tensor with shape (BATCH, 2, NUM_HEADS, SEQ_LEN, D_HEAD)""" - # save fp8 type - is_fp8_dtype = is_dtype_fp8(dtype) - if is_fp8_dtype: - og_fp8_dtype = dtype - dtype = torch.float32 - - # gen tensor - tensor_shape = (BATCH, 2, NUM_HEADS, SEQ_LEN, D_HEAD) - if DEBUG_INPUT: - x = ( - torch.arange(SEQ_LEN, dtype=dtype, device=device) - .view(1, 1, 1, SEQ_LEN, 1) - .expand(*tensor_shape) - .contiguous() - ) - else: - x = torch.randn(tensor_shape, dtype=dtype, device=device) - - if is_fp8_dtype: - # cast to fp8 - need to handle the packed dimension - raise NotImplementedError("FP8 not supported for KV packing yet") - else: - x.requires_grad_() - return x - - -def generate_varlen_qkv_packed( - total_seqlen: int, - num_heads: int, - head_size: int, - batch_size: Optional[int] = None, - equal_seqlens: bool = False, - device: str = "cuda", - dtype: torch.dtype = torch.float16, - DEBUG_INPUT: bool = False, -): - """Generate varlen QKV packed tensor with shape (total_seqlen, 3, num_heads, head_size)""" - if DEBUG: - print("generate_varlen_qkv_packed") - print("total_seqlen", total_seqlen) - print("num_heads", num_heads) - print("head_size", head_size) - - # save fp8 type - is_fp8_dtype = is_dtype_fp8(dtype) - if is_fp8_dtype: - og_fp8_dtype = dtype - dtype = torch.float32 - - # get valid batch_size - if batch_size is None: - valid_batch_sizes = [ - bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen - ] - batch_size = random.choice(valid_batch_sizes) - - # get seqlens - if equal_seqlens: - seqlens = torch.full( - (batch_size,), total_seqlen // batch_size, dtype=torch.int32, device=device - ) - seqlens[-1] += total_seqlen % batch_size - else: - seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device) - - # create cumulative sequence lengths - cu_seqlens = ( - torch.cat( - [torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)] - ) - .to(torch.int32) - .to(device=device) - ) - max_seqlen = torch.max(seqlens).to(torch.int32).item() - - # create varlen qkv packed tensor - if DEBUG_INPUT: - x = torch.zeros( - total_seqlen, 3, num_heads, head_size, dtype=dtype, device=device - ) - for i in range(batch_size): - start = cu_seqlens[i].item() - end = cu_seqlens[i + 1].item() - length = end - start - - x[start:end, :, :, :] = ( - torch.arange(length, dtype=dtype, device=device) - .view(length, 1, 1, 1) - .expand(length, 3, num_heads, head_size) - ) - else: - x = torch.randn( - (total_seqlen, 3, num_heads, head_size), dtype=dtype, device=device - ) - - if is_fp8_dtype: - # cast to fp8 - need to handle the packed dimension - raise NotImplementedError("FP8 not supported for QKV packing yet") - else: - x.requires_grad_() - return x, cu_seqlens, max_seqlen - - -def generate_varlen_kv_packed( - total_seqlen: int, - num_heads: int, - head_size: int, - batch_size: Optional[int] = None, - equal_seqlens: bool = False, - device: str = "cuda", - dtype: torch.dtype = torch.float16, - DEBUG_INPUT: bool = False, -): - """Generate varlen KV packed tensor with shape (total_seqlen, 2, num_heads, head_size)""" - if DEBUG: - print("generate_varlen_kv_packed") - print("total_seqlen", total_seqlen) - print("num_heads", num_heads) - print("head_size", head_size) - - # save fp8 type - is_fp8_dtype = is_dtype_fp8(dtype) - if is_fp8_dtype: - og_fp8_dtype = dtype - dtype = torch.float32 - - # get valid batch_size - if batch_size is None: - valid_batch_sizes = [ - bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen - ] - batch_size = random.choice(valid_batch_sizes) - - # get seqlens - if equal_seqlens: - seqlens = torch.full( - (batch_size,), total_seqlen // batch_size, dtype=torch.int32, device=device - ) - seqlens[-1] += total_seqlen % batch_size - else: - seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device) - - # create cumulative sequence lengths - cu_seqlens = ( - torch.cat( - [torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)] - ) - .to(torch.int32) - .to(device=device) - ) - max_seqlen = torch.max(seqlens).to(torch.int32).item() - - # create varlen kv packed tensor - if DEBUG_INPUT: - x = torch.zeros( - total_seqlen, 2, num_heads, head_size, dtype=dtype, device=device - ) - for i in range(batch_size): - start = cu_seqlens[i].item() - end = cu_seqlens[i + 1].item() - length = end - start - - x[start:end, :, :, :] = ( - torch.arange(length, dtype=dtype, device=device) - .view(length, 1, 1, 1) - .expand(length, 2, num_heads, head_size) - ) - else: - x = torch.randn( - (total_seqlen, 2, num_heads, head_size), dtype=dtype, device=device - ) - - if is_fp8_dtype: - # cast to fp8 - need to handle the packed dimension - raise NotImplementedError("FP8 not supported for KV packing yet") - else: - x.requires_grad_() - return x, cu_seqlens, max_seqlen - - -# ------------------------------- -# Alibi -# ------------------------------- -@triton.jit -def compute_alibi_block( - alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False -): - # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix - # for casual mask we want something like this where (1 is kept and 0 is masked) - # seqlen_q = 2 and seqlen_k = 5 - # 1 1 1 1 0 - # 1 1 1 1 1 - # seqlen_q = 5 and seqlen_k = 2 - # 0 0 - # 0 0 - # 0 0 - # 1 0 - # 1 1 - # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal - # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False - # 1. offs_m[:,None] = [[0], - # [1], - # 2. offs_m[:,None] + seqlen_k = [[5], - # [6], - # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], - # [4], - # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], - # [4], [ 4, 3, 2, 1, 0]] - # 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], - # [ -4, -3, -2, -1, 0]], - relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - if transpose: - return alibi_block.T - else: - return alibi_block - - # ------------------------------- # FP8 # ------------------------------- -def is_dtype_fp8(dtype) -> bool: - supported = { - torch.float8_e4m3fnuz, - torch.float8_e4m3fn, - torch.float8_e5m2, - torch.float8_e5m2fnuz, - } - if dtype not in supported: - return False - return True - - -_RECOMMENDED_FP8_REPLACEMENTS = { - "gfx942": { - torch.float8_e4m3fn: torch.float8_e4m3fnuz, - torch.float8_e5m2: torch.float8_e5m2fnuz, - }, -} +_FP8_DTYPES = frozenset({ + torch.float8_e4m3fnuz, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e5m2fnuz, +}) -def get_recommended_fp8_dtype(x): - dtype = x.dtype if isinstance(x, torch.Tensor) else x - if not is_dtype_fp8(dtype): - return dtype - arch = get_arch() - return _RECOMMENDED_FP8_REPLACEMENTS.get(arch, {}).get(dtype, dtype) +def is_fp8( + x: Union[torch.dtype, torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]], +) -> bool: + """Check if dtype/tensor(s) are FP8. + This is a pure function - it only checks dtypes, not architecture support. + Use `get_arch().supports_fp8` to check if the current GPU supports FP8. -def is_fp8(x) -> bool: - """Return whether tensor(s) use FP8. + Args: + x: A dtype, tensor, or list/tuple of tensors to check. - Accepts either a single tensor or a list/tuple of tensors. + Returns: + True if FP8, False otherwise. - Rules: - * Single tensor: return True if FP8 (after arch validation), else False. - * Multiple tensors: - - If all tensors are FP8 -> return True. - - If none are FP8 -> return False. - - If a mix of FP8 and non-FP8 -> raise ValueError. + Rules for multiple tensors: + - If all tensors are FP8 -> return True. + - If none are FP8 -> return False. + - If a mix of FP8 and non-FP8 -> raise ValueError. Empty list/tuple returns False. """ + # Handle dtype directly + if isinstance(x, torch.dtype): + return x in _FP8_DTYPES - def _is_fp8_single(t: torch.Tensor) -> bool: - if is_dtype_fp8(t.dtype): - arch = get_arch() - if arch not in ("gfx942", "gfx950"): - raise RuntimeError( - f"{arch} is not in the list of supported architectures for FP8" - ) - return True - return False + # Handle single tensor + if isinstance(x, torch.Tensor): + return x.dtype in _FP8_DTYPES + # Handle list/tuple of tensors if isinstance(x, (list, tuple)): if len(x) == 0: return False - flags = [_is_fp8_single(t) for t in x] + flags = [t.dtype in _FP8_DTYPES for t in x] if all(flags): return True if not any(flags): @@ -685,244 +174,12 @@ def _is_fp8_single(t: torch.Tensor) -> bool: raise ValueError( "Mixed FP8 and non-FP8 tensors provided; either all or none must be FP8." ) - else: - return _is_fp8_single(x) - - -@triton.jit -def compute_fp8_scaling_factors(x, fp8_max: tl.constexpr): - # compute fp8 scaling and descaling factor for a block - x_amax = tl.max(tl.abs(x)) # NOTE: abs deals with negative values - x_amax = tl.where(x_amax <= 1e-9, 1e-9, x_amax) - scale_x = fp8_max / x_amax - descale_x = x_amax / fp8_max - return scale_x, descale_x - - -@triton.jit -def _cast_varlen_to_fp8_kernel_2d( - X, - X_fp8, - Descale, - cu_seqlens, - H, - MAX_SEQLEN, - stride_batch, - stride_seq, - stride_head, - stride_dim, - stride_out_batch, - stride_out_seq, - stride_out_head, - stride_out_dim, - stride_desc_batch, - stride_desc_head, - FP8_CLAMP_VAL, - FP8_MAX, - BLOCK_SIZE: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, - IS_VARLEN: tl.constexpr, -): - # Process one (batch, head) pair per kernel - b_id = tl.program_id(0) - h_id = tl.program_id(1) - - # Get sequence bounds for this batch - if IS_VARLEN: - seq_start = tl.load(cu_seqlens + b_id) - seq_end = tl.load(cu_seqlens + b_id + 1) - seqlen = seq_end - seq_start - else: - seq_start = 0 - seqlen = MAX_SEQLEN - - # initialize max value tracker - x_max_val = 0.0 - - # STEP 1: Find max absolute value across the entire sequence - num_of_blocks = tl.cdiv(seqlen, BLOCK_SIZE) - for blk_idx in range(0, num_of_blocks): - # print("blk_idx:", blk_idx) - # offsets - offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs_dim = tl.arange(0, HEAD_DIM) - - # Create mask for valid elements - mask_seq = offs_seq[:, None] < seqlen - if ACTUAL_HEAD_DIM != HEAD_DIM: - mask_dim = offs_dim[None, :] < ACTUAL_HEAD_DIM - mask_seq = mask_seq & mask_dim - - # Load block - adj_x = ( - b_id * stride_batch - + h_id * stride_head - + seq_start * stride_seq - + offs_seq[:, None] * stride_seq - + offs_dim[None, :] * stride_dim - ) - x_block = tl.load(X + adj_x, mask=mask_seq, other=0.0) - # print("x_block:", x_block) - - # Find max absolute value in this block - block_max = tl.max(tl.abs(x_block)) - # print("block_max:", block_max) - - # Update overall max - x_max_val = tl.maximum(x_max_val, block_max) - # print("x_max_val:", x_max_val) - - # clamp to avoid division by zero issues - x_max_val = tl.maximum(x_max_val, FP8_CLAMP_VAL) - - # compute scale and descale factors for the entire sequence - scale = FP8_MAX / x_max_val - descale = x_max_val / FP8_MAX - - # store descale factor for this (batch, head) pair - desc_ptr = Descale + b_id * stride_desc_batch + h_id # * stride_desc_head - tl.store(desc_ptr, descale) - - # STEP 2: Apply scaling to the entire sequence and convert to FP8 - for blk_idx in range(0, num_of_blocks): - # offsets - offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs_dim = tl.arange(0, HEAD_DIM) - - # Create mask for valid elements - mask_seq = offs_seq[:, None] < seqlen - if ACTUAL_HEAD_DIM != HEAD_DIM: - mask_dim = offs_dim[None, :] < ACTUAL_HEAD_DIM - mask_seq = mask_seq & mask_dim - - # Load block - Using the fixed addressing - addr = ( - b_id * stride_batch - + h_id * stride_head - + seq_start * stride_seq - + offs_seq[:, None] * stride_seq - + offs_dim[None, :] * stride_dim - ) - x_block = tl.load(X + addr, mask=mask_seq, other=0.0) - - # Apply scale and convert to FP8 - x_fp8_block = (x_block * scale).to(X_fp8.type.element_ty) - - # Store results - addr_out = ( - b_id * stride_out_batch - + h_id * stride_out_head - + seq_start * stride_out_seq - + offs_seq[:, None] * stride_out_seq - + offs_dim[None, :] * stride_out_dim - ) - tl.store(X_fp8 + addr_out, x_fp8_block, mask=mask_seq) - - -def cast_to_fp8( - x: torch.Tensor, - fp8_dtype: torch.dtype, - layout: Literal["bshd", "thd"], - clamp_val: float = 1e-9, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, -) -> tuple[torch.Tensor, torch.Tensor]: - if False: - print() - print("cast_to_fp8") - print("x:", x, x.shape) - print("fp8_dtype:", fp8_dtype) - print("cu_seqlens:", cu_seqlens) - print("max_seqlen:", max_seqlen) - print("clamp_val:", clamp_val) - - # check types are valid - assert x.dtype in { - torch.float16, - torch.float32, - torch.float64, - torch.bfloat16, - } and is_dtype_fp8(fp8_dtype), f"Cannot cast {x.dtype} to {fp8_dtype}" - - # extract dimensions - batch, max_seqlen_final, num_heads, head_dim = get_shape_from_layout( - x, layout, cu_seqlens, max_seqlen - ) - is_varlen = layout == "thd" - fp8_max = torch.finfo(fp8_dtype).max - if False: - print("batch:", batch) - print("max_seqlen_final:", max_seqlen_final) - print("num_heads:", num_heads) - print("head_dim:", head_dim) - - # get closest power of 2 for head_dim - padded_head_dim = 1 << (head_dim - 1).bit_length() - padded_head_dim = max(padded_head_dim, 32) - - # kernel params - x_fp8 = torch.zeros_like(x, dtype=fp8_dtype) - descale_factors = torch.zeros( - (batch, num_heads), device=x.device, dtype=torch.float32 - ) - BLOCK_SIZE = 128 - # calculate strides - stride_batch, stride_head, stride_seq, stride_dim = get_stride_from_layout( - x, layout - ) - stride_out_batch, stride_out_head, stride_out_seq, stride_out_dim = ( - get_stride_from_layout(x_fp8, layout) - ) - stride_desc_batch, stride_desc_head = descale_factors.stride() - - if False: - print("stride_batch", stride_batch) - print("stride_head", stride_head) - print("stride_seq", stride_seq) - print("stride_dim", stride_dim) - print("stride_out_batch", stride_out_batch) - print("stride_out_head", stride_out_head) - print("stride_out_seq", stride_out_seq) - print("stride_out_dim", stride_out_dim) - print("stride_desc_batch", stride_desc_batch) - print("stride_desc_head", stride_desc_head) - - grid = (batch, num_heads) - _cast_varlen_to_fp8_kernel_2d[grid]( - x, - x_fp8, - descale_factors, - cu_seqlens, - num_heads, - max_seqlen_final, - stride_batch, - stride_seq, - stride_head, - stride_dim, - stride_out_batch, - stride_out_seq, - stride_out_head, - stride_out_dim, - stride_desc_batch, - stride_desc_head, - clamp_val, - fp8_max, - BLOCK_SIZE=BLOCK_SIZE, - HEAD_DIM=padded_head_dim, - ACTUAL_HEAD_DIM=head_dim, - IS_VARLEN=is_varlen, - ) - - if False: - print("x_fp8:", x_fp8, x_fp8.shape) - print("descale_factors:", descale_factors, descale_factors.shape) - return x_fp8, descale_factors + raise TypeError(f"Expected dtype, Tensor, or sequence of Tensors, got {type(x)}") # ------------------------------- -# Misc +# Shape/Stride Helpers # ------------------------------- def get_shape_from_layout( x: torch.Tensor, @@ -930,6 +187,7 @@ def get_shape_from_layout( cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, ) -> tuple[int, int, int, int]: + """Extract (batch, max_seqlen, num_heads, head_dim) from tensor based on layout.""" if layout == "bhsd": batch, num_heads, max_seqlen_final, head_dim = x.shape elif layout == "bshd": @@ -948,35 +206,15 @@ def get_shape_from_layout( head_dim, ) else: - assert False, "Got unsupported layout." + raise ValueError(f"Got unsupported layout: {layout}") return batch, max_seqlen_final, num_heads, head_dim -def get_shapes_from_layout( - q, - k, - layout, - cu_seqlens_q=None, - cu_seqlens_k=None, - max_seqlen_q=None, - max_seqlen_k=None, -): - batch_q, seqlen_q, nheads_q, head_size_q = get_shape_from_layout( - q, layout, cu_seqlens_q, max_seqlen_q - ) - batch_k, seqlen_k, nheads_k, head_size_k = get_shape_from_layout( - k, layout, cu_seqlens_k, max_seqlen_k - ) - - # assert - assert batch_q == batch_k - assert head_size_q == head_size_k - - return batch_q, nheads_q, nheads_k, head_size_q, seqlen_q, seqlen_k - - -def get_stride_from_layout(x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd"]): +def get_stride_from_layout( + x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd"] +) -> tuple[int, int, int, int]: + """Get strides in (batch, head, seq, dim) order for the given layout.""" if layout == "thd": strides = (0, x.stride(1), x.stride(0), x.stride(2)) elif layout == "bhsd": @@ -984,547 +222,43 @@ def get_stride_from_layout(x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd elif layout == "bshd": strides = (x.stride(0), x.stride(2), x.stride(1), x.stride(3)) else: - assert False, "Got unsupported layout." + raise ValueError(f"Got unsupported layout: {layout}") return strides -def get_shape_and_strides_from_layout( - x: torch.Tensor, - layout: Literal["bshd", "bhsd", "thd"], - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, -): - return get_shape_from_layout( - x, layout, cu_seqlens, max_seqlen - ), get_stride_from_layout(x, layout) - - -def get_strides_from_layout(q, k, v, o, layout): - q_strides = get_stride_from_layout(q, layout) - k_strides = get_stride_from_layout(k, layout) - v_strides = get_stride_from_layout(v, layout) - o_strides = get_stride_from_layout(o, layout) - return q_strides, k_strides, v_strides, o_strides - - -def get_padded_headsize(size): - # Get closest power of 2 over or equal to 32. - padded_d_model = 1 << (size - 1).bit_length() +def get_padded_headsize(size: int) -> int: + """Get closest power of 2 over or equal to 32.""" # Smallest head_dim supported is 16. If smaller, the tile in the # kernel is padded - there is no padding in memory for any dims. + padded_d_model = 1 << (size - 1).bit_length() padded_d_model = max(padded_d_model, 16) return padded_d_model -def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k): - q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze( - -1 - ) # (N_CTX_Q, 1) - k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze( - 0 - ) # (1, N_CTX_K) - relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) - return ( - -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos - ) # (Z, H, N_CTX_Q, N_CTX_K) - - -def round_multiple(x, m): - return (x + m - 1) // m * m - - -def save_tensor_to_csv(tensor, filename, decimal_places=2): - """ - save a 2d tensor to csv file - - args: - tensor: torch tensor of shape [rows, cols] - filename: output csv filename - decimal_places: number of decimal places (default: 2) - """ - # ensure tensor is 2d - if tensor.ndim != 2: - raise ValueError(f"tensor must be 2d, got shape {tensor.shape}") - - # ensure filename ends with .csv - if not filename.endswith(".csv"): - filename = filename + ".csv" - - # save to csv using numpy - np.savetxt( - filename, - tensor.detach().cpu().numpy(), - delimiter=",", - fmt=f"%.{decimal_places}f", - ) - - # ------------------------------- -# Dropouts +# Misc helpers # ------------------------------- -def create_dropout_mask(dropout_p, shape, seed): - device = "cuda" - rand_vals = torch.rand( - shape, - generator=torch.Generator(device=device).manual_seed(seed), - device=device, - dtype=torch.float32, - ) - return rand_vals > dropout_p - - -def create_dropout_mask_varlen( - dropout_p, batch, nheads_q, cu_seqlens_q, cu_seqlens_k, philox_seed -): - device = "cuda" - qlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - klens = cu_seqlens_k[1:] - cu_seqlens_k[:-1] - max_qlen = qlens.max() - max_klen = klens.max() - dropout_mask = torch.zeros((batch, nheads_q, max_qlen, max_klen), device=device) - for b in range(batch): - qlen = qlens[b] - klen = klens[b] - rand_vals = torch.rand( - (nheads_q, qlen, klen), - generator=torch.Generator(device=device).manual_seed(philox_seed), - device=device, - dtype=torch.float32, - ) - submask = rand_vals > dropout_p - dropout_mask[b, :, :qlen, :klen] = submask - - return dropout_mask - - -def write_dropout_mask(x, tensor_name="tensor"): - batch, head, seqlen_m, seqlen_n = x.shape - x = x.tolist() - - with open(f"{tensor_name}.csv", "w") as f: - writer = csv.writer(f) - for b in range(batch): - for h in range(head): - dropout_mask = x[b][h] - if True: - BLOCK_M = 64 - BLOCK_N = 64 - - # Calculate number of blocks in each dimension - m_blocks = math.ceil(seqlen_m / BLOCK_M) - n_blocks = math.ceil(seqlen_n / BLOCK_N) - - # Process each block - for m_block in range(m_blocks): - # Calculate row range for current block - row_start = m_block * BLOCK_M - row_end = min(row_start + BLOCK_M, seqlen_m) - - for n_block in range(n_blocks): - # Calculate column range for current block - col_start = n_block * BLOCK_N - col_end = min(col_start + BLOCK_N, seqlen_n) - - # Extract and write the current block - for row_idx in range(row_start, row_end): - row_data = dropout_mask[row_idx][col_start:col_end] - writer.writerow(row_data) - else: - writer.writerows(dropout_mask) - - -# ------------------------------- -# Rotary -# ------------------------------- -@triton.jit -def _rotary_kernel( - OUT, - X, - COS, - SIN, - CU_SEQLENS, - SEQLEN_OFFSETS, - seqlen, - nheads, - seqlen_ro, - stride_out_batch, - stride_out_seqlen, - stride_out_nheads, - stride_out_headdim, - stride_x_batch, - stride_x_seqlen, - stride_x_nheads, - stride_x_headdim, - ROTARY_DIM: tl.constexpr, - IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, - IS_VARLEN: tl.constexpr, - INTERLEAVED: tl.constexpr, - CONJUGATE: tl.constexpr, - BLOCK_H: tl.constexpr, - BLOCK_M: tl.constexpr, -): - BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM) - ROTARY_DIM_HALF = ROTARY_DIM // 2 - pid_head = tl.program_id(axis=0) - pid_m = tl.program_id(axis=1) - pid_batch = tl.program_id(axis=2) - - if not IS_VARLEN: - X = X + pid_batch * stride_x_batch - OUT = OUT + pid_batch * stride_out_batch - else: - start_idx = tl.load(CU_SEQLENS + pid_batch) - seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - X = X + start_idx * stride_x_seqlen - OUT = OUT + start_idx * stride_out_seqlen - - if pid_m * BLOCK_M >= seqlen: - return - - rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - if not IS_SEQLEN_OFFSETS_TENSOR: - rm_cs = rm + SEQLEN_OFFSETS - else: - rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) - - rk_half = tl.arange(0, BLOCK_K // 2) - COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) - SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) - mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF) - cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32) - sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32) - if CONJUGATE: - sin = -sin - - if not INTERLEAVED: - X = X + ( - rh[:, None, None] * stride_x_nheads - + rm[None, :, None] * stride_x_seqlen - + rk_half[None, None, :] * stride_x_headdim - ) - OUT = OUT + ( - rh[:, None, None] * stride_out_nheads - + rm[None, :, None] * stride_out_seqlen - + rk_half[None, None, :] * stride_out_headdim - ) - mask = ( - (rh[:, None, None] < nheads) - & (rm[None, :, None] < seqlen) - & (rk_half[None, None, :] < ROTARY_DIM_HALF) - ) - x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32) - x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0).to( - tl.float32 - ) - o0 = x0 * cos - x1 * sin - o1 = x0 * sin + x1 * cos - tl.store(OUT, o0, mask=mask) - tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask) - else: - rk = tl.arange(0, BLOCK_K) - X = X + ( - rh[:, None, None] * stride_x_nheads - + rm[None, :, None] * stride_x_seqlen - + rk[None, None, :] * stride_x_headdim - ) - OUT = OUT + ( - rh[:, None, None] * stride_out_nheads - + rm[None, :, None] * stride_out_seqlen - + rk[None, None, :] * stride_out_headdim - ) - mask = ( - (rh[:, None, None] < nheads) - & (rm[None, :, None] < seqlen) - & (rk[None, None, :] < ROTARY_DIM) - ) - x = tl.load(X, mask=mask, other=0.0).to(tl.float32) - x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2])) - o0 = x0 * cos - x1 * sin - o1 = x0 * sin + x1 * cos - o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K]) - tl.store(OUT, o, mask=mask) - - -def _apply_rotary_kernel( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, - interleaved: bool = False, - inplace: bool = False, - conjugate: bool = False, -) -> torch.Tensor: - is_varlen = cu_seqlens is not None - if not is_varlen: - batch, seqlen, nheads, headdim = x.shape - else: - assert ( - max_seqlen is not None - ), "If cu_seqlens is passed, max_seqlen must also be provided" - total_seqlen, nheads, headdim = x.shape - batch_p_1 = cu_seqlens.shape[0] - batch = batch_p_1 - 1 - seqlen = max_seqlen - seqlen_ro, rotary_dim_half = cos.shape - assert sin.shape == cos.shape - rotary_dim = 2 * rotary_dim_half - assert rotary_dim <= headdim - assert headdim <= 256 - assert seqlen_ro >= seqlen - - cos, sin = cos.contiguous(), sin.contiguous() - if isinstance(seqlen_offsets, torch.Tensor): - assert seqlen_offsets.shape == (batch,) - assert seqlen_offsets.dtype in (torch.int32, torch.int64) - seqlen_offsets = seqlen_offsets.contiguous() - else: - assert seqlen_offsets + seqlen <= seqlen_ro - - out = torch.empty_like(x) if not inplace else x - if rotary_dim < headdim and not inplace: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - - # Block heuristics - BLOCK_M = 8 if rotary_dim <= 128 else 4 - grid = ( - triton.cdiv(nheads, 2), - triton.cdiv(seqlen, BLOCK_M), - batch, - ) - - with torch.cuda.device(x.device.index): - torch.library.wrap_triton(_rotary_kernel)[grid]( - out, - x, - cos, - sin, - cu_seqlens, - seqlen_offsets, - seqlen, - nheads, - seqlen_ro, - out.stride(0) if not is_varlen else 0, - out.stride(-3), - out.stride(-2), - out.stride(-1), - x.stride(0) if not is_varlen else 0, - x.stride(-3), - x.stride(-2), - x.stride(-1), - rotary_dim, - isinstance(seqlen_offsets, torch.Tensor), - is_varlen, - interleaved, - conjugate, - BLOCK_M=BLOCK_M, - BLOCK_H=2, - ) - return out - - -class _ApplyRotary(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - interleaved: bool, - inplace: bool, - seqlen_offsets: Union[int, torch.Tensor], - cu_seqlens: Optional[torch.Tensor], - max_seqlen: Optional[int], - ): - out = _apply_rotary_kernel( - x, - cos, - sin, - seqlen_offsets=seqlen_offsets, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - interleaved=interleaved, - inplace=inplace, - conjugate=False, - ) - if isinstance(seqlen_offsets, int): - ctx.save_for_backward(cos, sin, cu_seqlens) - ctx.seqlen_offsets = seqlen_offsets - else: - ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) - ctx.seqlen_offsets = None - ctx.interleaved = interleaved - ctx.inplace = inplace - ctx.max_seqlen = max_seqlen - return out if not inplace else x - - @staticmethod - def backward(ctx, do: torch.Tensor): - seqlen_offsets = ctx.seqlen_offsets - if seqlen_offsets is None: - cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors - else: - cos, sin, cu_seqlens = ctx.saved_tensors - dx = _apply_rotary_kernel( - do, - cos, - sin, - seqlen_offsets=seqlen_offsets, - cu_seqlens=cu_seqlens, - max_seqlen=ctx.max_seqlen, - interleaved=ctx.interleaved, - inplace=ctx.inplace, - conjugate=True, - ) - return dx, None, None, None, None, None, None, None - - -def apply_rotary_emb( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - interleaved: bool = False, - inplace: bool = False, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[int] = None, -) -> torch.Tensor: - """Public API: apply rotary embeddings to tensor x. - - Args: - x: (B, S, H, D) if `cu_seqlens` is None else (total_S, H, D). - cos, sin: (S_rotary, rotary_dim/2) - interleaved: GPT-J style if True. - inplace: modify x in place (saves memory if rotary_dim == D). - seqlen_offsets: int or (B,) tensor of starting offsets per sequence (KV cache decode). - cu_seqlens: (B+1,) tensor enabling varlen mode. - max_seqlen: required when `cu_seqlens` is provided. - """ - # FP8 path: upcast to bfloat16 (preferred) or float16 for rotary math to avoid excessive error - original_dtype = x.dtype - is_fp8_input = original_dtype == getattr(torch, "float8_e4m3fn", None) - if is_fp8_input: - # Choose bf16 if available in cos.dtype path; otherwise fallback to float16 - target_dtype = ( - torch.bfloat16 - if cos.dtype == torch.bfloat16 or torch.cuda.is_bf16_supported() - else torch.float16 - ) - # Upcast x, cos, sin for computation (without modifying originals in-place) - x_up = x.to(target_dtype) - cos_up = cos.to(target_dtype) if cos.dtype != target_dtype else cos - sin_up = sin.to(target_dtype) if sin.dtype != target_dtype else sin - out_up = _ApplyRotary.apply( - x_up, - cos_up, - sin_up, - interleaved, - False, - seqlen_offsets, - cu_seqlens, - max_seqlen, - ) - # Cast result back to original fp8 dtype - if inplace: - x.copy_(out_up.to(original_dtype)) - return x - return out_up.to(original_dtype) - else: - return _ApplyRotary.apply( - x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen - ) - - -def apply_rotary( - q: torch.Tensor, - k_new: Optional[torch.Tensor], - cos: torch.Tensor, - sin: torch.Tensor, - *, - causal: bool, - local: bool, - interleaved: bool = False, - seqlen_offsets: Union[int, torch.Tensor] = 0, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """High-level rotary application used by AMD prefill & decode paths. - - Policy (matches test reference & legacy semantics): - - If causal OR local attention ⇒ apply rotary directly on (B, S, H, D). - - Else (non-causal global) ⇒ flatten heads into sequence: (B, 1, S*H, D), - apply rotary once, then unflatten back. - - k_new (incremental KV slice) is always rotated directly when provided. - - Args: - q: (B, S, H, D) - k_new: Optional (B, S_k, H_k, D) - cos, sin: rotary caches (S_rotary, rotary_dim/2) - causal: causal attention flag - local: sliding-window / local attention flag (pre-computed outside) - interleaved: GPT-J style rotary layout - seqlen_offsets: int or (B,) tensor of per-sequence start offsets - Returns: - (q_rot, k_new_rot) - """ - assert q.ndim == 4, f"Expected q shape (B,S,H,D), got {q.shape}" - B, S, H, D = q.shape - use_flatten = (not causal) and (not local) - - if use_flatten: - # Flatten (S,H) -> (S*H) with an added singleton dim to preserve expected 4D shape. - q_flat = q.reshape(B, S * H, D).unsqueeze(1) # (B, 1, S*H, D) - q_flat = apply_rotary_emb( - q_flat, - cos, - sin, - interleaved=interleaved, - seqlen_offsets=seqlen_offsets, - ) - # Restore shape back to (B, S, H, D) - q = q_flat.view(B, 1, S * H, D).reshape(B, S, H, D) - else: - q = apply_rotary_emb( - q, - cos, - sin, - interleaved=interleaved, - seqlen_offsets=seqlen_offsets, - ) - - if k_new is not None: - k_new = apply_rotary_emb( - k_new, - cos, - sin, - interleaved=interleaved, - seqlen_offsets=seqlen_offsets, - ) - return q, k_new +def round_multiple(x: int, m: int) -> int: + """Round x up to the nearest multiple of m.""" + return (x + m - 1) // m * m # ------------------------------- # Runtime info # ------------------------------- @functools.cache -def is_hip(): - return triton.runtime.driver.active.get_current_target().backend == "hip" +def is_hip() -> bool: + """Check if running on HIP (AMD) backend.""" + return bool(triton.runtime.driver.active.get_current_target().backend == "hip") @functools.cache def get_arch() -> GpuArch: """Get the current GPU architecture.""" - name = triton.runtime.driver.active.get_current_target().arch + name: str = triton.runtime.driver.active.get_current_target().arch if name in CDNA_ARCHS: return GpuArch(name=name, family="cdna") elif name in RDNA_ARCHS: return GpuArch(name=name, family="rdna") else: return GpuArch(name=name) - - -@functools.cache -def get_cu_count(): - return torch.cuda.get_device_properties( - torch.cuda.current_device() - ).multi_processor_count