From de09e4e982ad519c270a0bff53d2d3bc3ade6ab1 Mon Sep 17 00:00:00 2001 From: Michael Date: Tue, 30 Sep 2025 13:48:37 -0500 Subject: [PATCH 01/33] check cu count for gfx942 --- .../flash_attn_triton_amd/fwd_prefill.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index d1036f98c3f..cf866e0f0d0 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -150,13 +150,20 @@ def get_fwd_prefill_autotune_configs(): num_warps=4, ) elif ( - arch == "gfx942" and False - ): # Disabled due shared mem oom in CI when using triton==3.3.0 when using top of tree everything seems fine. - default_config = triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ) + arch == "gfx942" + ): + if torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count < 304: + default_config = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + else: + default_config = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) else: default_config = triton.Config( {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, From 94df7bd1a1cc55877d38446a274961178d287ed1 Mon Sep 17 00:00:00 2001 From: Michael Date: Tue, 30 Sep 2025 13:58:55 -0500 Subject: [PATCH 02/33] create get_cu_count --- flash_attn/flash_attn_triton_amd/fwd_prefill.py | 3 ++- flash_attn/flash_attn_triton_amd/utils.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index cf866e0f0d0..bc77365a9c2 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -11,6 +11,7 @@ compute_alibi_block, compute_fp8_scaling_factors, get_arch, + get_cu_count, is_cdna, is_fp8, is_rdna, @@ -152,7 +153,7 @@ def get_fwd_prefill_autotune_configs(): elif ( arch == "gfx942" ): - if torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count < 304: + if get_cu_count() < 304: default_config = triton.Config( {"BLOCK_M": 128, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 71ed1c1c2de..1d5f6d82bf7 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -1736,6 +1736,9 @@ def is_hip(): def get_arch(): return triton.runtime.driver.active.get_current_target().arch +@functools.cache +def get_cu_count(): + return torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count @functools.cache def is_cdna(): From 01496469acfa32d16fc7d67b75cdb3500b6d8888 Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 1 Oct 2025 10:30:39 -0500 Subject: [PATCH 03/33] update repo root --- hopper/flash_attn_interface.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index b7ede7dc442..f5a31027976 100755 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -3,15 +3,18 @@ from typing import Optional, Union import os +import sys +from pathlib import Path import torch import torch.nn as nn USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" if USE_TRITON_ROCM: - import sys - sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - from flash_attn.flash_attn_triton_amd import flash_attn_3 as flash_attn_3_gpu + repo_root = Path(__file__).resolve().parent.parent + if str(repo_root) not in sys.path: + sys.path.insert(0, str(repo_root)) + from flash_attn.flash_attn_triton_amd import flash_attn_3 as flash_attn_3_gpu # type: ignore else: # isort: off # We need to import the CUDA kernels after importing torch From 6e6b4a5f01bc6f6d4ee7ced505799eee54518a6a Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 1 Oct 2025 12:50:37 -0500 Subject: [PATCH 04/33] update forward tune --- flash_attn/flash_attn_triton_amd/bwd.py | 5 +- .../flash_attn_triton_amd/fwd_prefill.py | 183 +++++------------- .../flash_attn_triton_amd/interface_v3.py | 2 +- 3 files changed, 50 insertions(+), 140 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 4d4c22866d6..84e4fa1d61d 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -5,6 +5,7 @@ from typing import Literal, Optional from .utils import ( DEBUG, + AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, @@ -19,7 +20,7 @@ tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) -def get_bwd_configs(autotune = False): +def get_bwd_configs(autotune: bool): # default config if not autotune: # preprocess params @@ -159,7 +160,7 @@ def get_bwd_configs(autotune = False): (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys), -) = get_bwd_configs() +) = get_bwd_configs(AUTOTUNE) # This function computes delta given output Out and gradient DO diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index bc77365a9c2..a5855401864 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -24,48 +24,8 @@ tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) -# ------------------------------- -# Autotune -# ------------------------------- -def get_fwd_prefill_cdna_autotune_configs(): - return [ - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - # Fall-back config. - triton.Config( - {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ), - ], [ +def get_fwd_configs(autotune: bool): + keys = [ "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", @@ -77,117 +37,66 @@ def get_fwd_prefill_cdna_autotune_configs(): "HK", ] - -def get_fwd_prefill_rdna_autotune_configs(): - return [ - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 4, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 16, "waves_per_eu": 4, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 16, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - triton.Config( - {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 4, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - triton.Config( - {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - # Fall-back config. - triton.Config( - {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - ), - ], [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_BLOCK_DMODEL_QK", - "ACTUAL_BLOCK_DMODEL_V", - "IS_VARLEN", - "HQ", - "HK", - ] - - -def get_fwd_prefill_autotune_configs(): - if AUTOTUNE: - if is_rdna(): - return get_fwd_prefill_rdna_autotune_configs() - elif is_cdna(): - return get_fwd_prefill_cdna_autotune_configs() - else: - raise ValueError("Unknown Device Type") - else: + if not autotune: arch = get_arch() if arch == "gfx950": - default_config = triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, + cfg = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, num_warps=4, ) - elif ( - arch == "gfx942" - ): + elif arch == "gfx942": if get_cu_count() < 304: - default_config = triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ) + cfg = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=2, + ) else: - default_config = triton.Config( + cfg = triton.Config( {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, num_warps=4, ) else: - default_config = triton.Config( + cfg = triton.Config( {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, num_warps=4, ) - - return [default_config], [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_BLOCK_DMODEL_QK", - "ACTUAL_BLOCK_DMODEL_V", - "IS_VARLEN", - "HQ", - "HK", - ] - - -fwd_prefill_autotune_configs, fwd_prefill_autotune_keys = ( - get_fwd_prefill_autotune_configs() -) + return [cfg], keys + + # ===================== Autotune Sweep ===================== + BLOCK_M_OPTIONS = [128, 64, 32, 16] + BLOCK_N_OPTIONS = [128, 64, 32, 16] + NUM_WARPS_OPTIONS = [2, 4, 8] + NUM_STAGES_OPTIONS = [1, 2] + WAVES_PER_EU_OPTIONS = [4, 2, 1] + PRE_LOAD_V_OPTIONS = [False] + + configs = [] + for bm in BLOCK_M_OPTIONS: + for bn in BLOCK_N_OPTIONS: + for waves in WAVES_PER_EU_OPTIONS: + for nw in NUM_WARPS_OPTIONS: + for ns in NUM_STAGES_OPTIONS: + for preload_v in PRE_LOAD_V_OPTIONS: + configs.append( + triton.Config( + { + "BLOCK_M": bm, + "BLOCK_N": bn, + "waves_per_eu": waves, + "PRE_LOAD_V": preload_v, + }, + num_stages=ns, + num_warps=nw, + ) + ) + + return configs, keys + +fwd_prefill_autotune_configs, fwd_prefill_autotune_keys = get_fwd_configs(AUTOTUNE) @triton.jit diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py index 436077a8a7c..06fd4131e98 100755 --- a/flash_attn/flash_attn_triton_amd/interface_v3.py +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -241,7 +241,7 @@ def fwd( ) if out is None: - out_dtype = torch.float32 if is_fp8(q) else q.dtype + out_dtype = torch.bfloat16 if is_fp8(q) else q.dtype if layout == "bshd": out = torch.zeros( q.shape[0], From e967c144411acd4a75891198e3e737521d48999b Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 1 Oct 2025 14:32:27 -0500 Subject: [PATCH 05/33] clean up load --- .../flash_attn_triton_amd/fwd_prefill.py | 42 ++++++++----------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index a5855401864..ff7c0b4cefd 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -159,28 +159,20 @@ def _attn_fwd_no_mask( v_ptrs = v_base_ptrs + start_n * stride_vk kv_offs_n = start_n + tl.arange(0, BLOCK_N) + # Load K if PADDED_HEAD_QK: - k_mask, k_mask_other = (offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK), 0.0 + k_mask = offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK + k = tl.load(k_ptrs, mask=k_mask, other=0.0) else: - k_mask, k_mask_other = None, None + k = tl.load(k_ptrs) - if PADDED_HEAD_V: - v_mask, v_mask_other = (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V), 0.0 - else: - v_mask, v_mask_other = None, None - - # load k and if preload_v then v - k = ( - tl.load(k_ptrs, mask=k_mask, other=k_mask_other) - if PADDED_HEAD_QK - else tl.load(k_ptrs) - ) + # Optionally preload V if PRE_LOAD_V: - v = ( - tl.load(v_ptrs, mask=v_mask, other=v_mask_other) - if PADDED_HEAD_V - else tl.load(v_ptrs) - ) + if PADDED_HEAD_V: + v_mask = offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + else: + v = tl.load(v_ptrs) # setup qk accumlator qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) @@ -260,11 +252,11 @@ def _attn_fwd_no_mask( alpha = tl.math.exp(m_diff) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = ( - tl.load(v_ptrs, mask=v_mask, other=v_mask_other) - if PADDED_HEAD_V - else tl.load(v_ptrs) - ) + if PADDED_HEAD_V: + v_mask = offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + else: + v = tl.load(v_ptrs) # -- update m_i and l_i l_i = l_i * alpha + l_ij @@ -891,7 +883,6 @@ def attn_fwd( stride_q_descale_z, stride_k_descale_z, stride_v_descale_z, - SM_SCALE: tl.constexpr, LSE, Out, stride_qz, @@ -940,6 +931,7 @@ def attn_fwd( MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, + SM_SCALE: tl.constexpr, IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, @@ -1849,7 +1841,6 @@ def attention_forward_prefill_triton_impl( stride_q_descale_z, stride_k_descale_z, stride_v_descale_z, - sm_scale, softmax_lse, o, stride_qb, @@ -1897,6 +1888,7 @@ def attention_forward_prefill_triton_impl( ACTUAL_BLOCK_DMODEL_V=head_size_v, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k, + SM_SCALE=sm_scale, IS_CAUSAL=causal, USE_SLIDING_WINDOW=use_sliding_window, WINDOW_SIZE_LEFT=window_size_left, From 371bec569bd324d66c95d3dd3569a90cc0619c91 Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 1 Oct 2025 16:22:07 -0500 Subject: [PATCH 06/33] use float8_e4m3fnuz --- flash_attn/flash_attn_triton_amd/utils.py | 25 ++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 1d5f6d82bf7..e929ad87e4f 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -890,18 +890,29 @@ def compute_alibi_block( # FP8 # ------------------------------- def is_dtype_fp8(dtype): - if dtype in { + supported = { torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz, - }: - if arch_supports_fp8(): - return True - else: - raise RuntimeError("This device doesnot support fp8") - else: + } + if dtype not in supported: return False + if not arch_supports_fp8(): + raise RuntimeError("This device does not support FP8 on this architecture") + + # check for architecture-specific restrictions + arch = get_arch() + if arch == "gfx942": + if dtype == torch.float8_e4m3fn: + replacement_dtype = torch.float8_e4m3fnuz + elif dtype == torch.float8_e5m2: + replacement_dtype = torch.float8_e5m2fnuz + else: + replacement_dtype = None + if replacement_dtype is not None: + raise TypeError(f"On {arch} use {replacement_dtype} instead of {dtype}") + return True def is_fp8(x): From 1f2aaa03238a6c3fde1a44d44ef63713266eedcc Mon Sep 17 00:00:00 2001 From: Michael Date: Thu, 2 Oct 2025 05:50:45 -0500 Subject: [PATCH 07/33] save --- flash_attn/flash_attn_triton_amd/bwd.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 84e4fa1d61d..b4f25daeee4 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -4790,7 +4790,6 @@ def attention_backward_triton_impl( ) if mode == "fused_atomics": - # Atomics path ignores layout & use_exp2; pass varlen metadata directly. return attention_backward_triton_fused_atomics_impl( do, q, From 56eba6125b4af2b6d4fda9d66df43a7f66701578 Mon Sep 17 00:00:00 2001 From: Michael Date: Thu, 2 Oct 2025 06:22:26 -0500 Subject: [PATCH 08/33] show bwd mode --- flash_attn/flash_attn_triton_amd/interface_v2.py | 4 ++-- flash_attn/flash_attn_triton_amd/interface_v3.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/interface_v2.py b/flash_attn/flash_attn_triton_amd/interface_v2.py index 134c4a76c12..9df299666eb 100644 --- a/flash_attn/flash_attn_triton_amd/interface_v2.py +++ b/flash_attn/flash_attn_triton_amd/interface_v2.py @@ -211,7 +211,7 @@ def bwd( # call implementation if DEBUG: - print("Using Triton implementation") + print(f"Using Triton implementation in {BWD_MODE} mode") delta = attention_backward_triton_impl( do=dout, q=q, @@ -494,7 +494,7 @@ def varlen_bwd( # call implementation if DEBUG: - print("Using Triton implementation") + print(f"Using Triton implementation in {BWD_MODE} mode") delta = attention_backward_triton_impl( do=dout, q=q, diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py index 06fd4131e98..bb3f7315269 100755 --- a/flash_attn/flash_attn_triton_amd/interface_v3.py +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -499,7 +499,7 @@ def bwd( # Call implementation if DEBUG: - print("Using Triton implementation (unified backward dispatcher)") + print(f"Using Triton implementation in {BWD_MODE} mode") delta = attention_backward_triton_impl( do=dout, q=q, From 0218cd2c8e26715f0ddedf6274ec3cb39b68a338 Mon Sep 17 00:00:00 2001 From: Michael Date: Fri, 3 Oct 2025 19:04:36 -0500 Subject: [PATCH 09/33] recommend fp8 --- flash_attn/flash_attn_triton_amd/bwd.py | 6 +- .../flash_attn_triton_amd/fwd_decode.py | 12 ++- .../flash_attn_triton_amd/fwd_prefill.py | 27 ++++++- .../flash_attn_triton_amd/interface_v3.py | 44 ++++++++--- flash_attn/flash_attn_triton_amd/utils.py | 76 +++++++++++++------ 5 files changed, 126 insertions(+), 39 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index b4f25daeee4..e47adae3321 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -3981,8 +3981,8 @@ def attention_backward_triton_split_fused_no_atomics_impl( stride_dob, stride_dom, stride_doh, stride_dod = do.stride() stride_lse_b, stride_lse_h, stride_lse_m = softmax_lse.stride() - # fp8 setup - moved after all assertions - IS_FP8 = is_fp8(q) + # fp8 + IS_FP8 = is_fp8([q, k, v]) if IS_FP8: FP8_MAX = torch.finfo(q.dtype).max # we already asserted that do, q, k, v all have the same dtype, so no need to check each one @@ -4331,7 +4331,7 @@ def attention_backward_triton_fused_atomics_impl( seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, ): - IS_FP8 = is_fp8(q) + IS_FP8 = is_fp8([q, k, v]) if IS_FP8: FP8_MAX = torch.finfo(q.dtype).max descale_strides = ( diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index bb7edad3494..5186f52ecdf 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -1,3 +1,4 @@ +import os import torch import triton import triton.language as tl @@ -10,6 +11,7 @@ apply_rotary, is_cdna, is_fp8, + get_recommended_fp8_dtype, ) @@ -1118,8 +1120,16 @@ def attention_forward_decode_triton_impl( stride_bt_b, stride_bt_s = 0, 0 # FP8 support - IS_FP8 = is_fp8(q) + IS_FP8 = is_fp8([q, k_cache, v_cache]) if IS_FP8: + CAST_TO_REC = str(os.getenv("CAST_TO_REC", "0")).lower() in ("1", "true", "yes", "on") + if CAST_TO_REC: + rec = get_recommended_fp8_dtype(q) + if q.dtype != rec: + raise TypeError( + f"FP8 dtype mismatch: received {q.dtype}, expected recommended {rec}. " + "Convert to the recommended FP8 dtype before calling (handled in interface)." + ) if (q_descale is None) or (k_descale is None) or (v_descale is None): import warnings diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index ff7c0b4cefd..1ee7c3e5b8e 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -17,6 +17,7 @@ is_rdna, create_dropout_mask, apply_rotary, + get_recommended_fp8_dtype, ) # NOTE: triton fails to import tl.constexprs so create them here for the file @@ -24,7 +25,7 @@ tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) -def get_fwd_configs(autotune: bool): +def get_fwd_configs(autotune: bool, use_fallback: bool = True): keys = [ "IS_CAUSAL", "dropout_p", @@ -37,7 +38,17 @@ def get_fwd_configs(autotune: bool): "HK", ] + # default configs if not autotune: + # TODO: don't use fallback config used for function correctness testing due to some configs leading error on the scale of 1e-1. + if use_fallback: + cfg = triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + + # get best config for the architecture arch = get_arch() if arch == "gfx950": cfg = triton.Config( @@ -64,6 +75,7 @@ def get_fwd_configs(autotune: bool): num_stages=1, num_warps=4, ) + return [cfg], keys # ===================== Autotune Sweep ===================== @@ -1716,12 +1728,19 @@ def attention_forward_prefill_triton_impl( ) # fp8 setup and assertions - IS_FP8 = is_fp8(q) + IS_FP8 = is_fp8([q, k, v]) if IS_FP8: - # we already asserted that q, k, v all have the same dtype, so no need to check each one - FP8_MAX = torch.finfo(q.dtype).max + CAST_TO_REC = str(os.getenv("CAST_TO_REC", "0")).lower() in ("1", "true", "yes", "on") + if CAST_TO_REC: + # check fp8 is the correct dtype for this architecture + rec = get_recommended_fp8_dtype(q) + if q.dtype != rec: + raise TypeError( + f"FP8 dtype mismatch: received {q.dtype}, expected recommended {rec} for this architecture. " + ) + # Check and create default descale tensors if not provided if (q_descale is None) or (k_descale is None) or (v_descale is None): import warnings diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py index bb3f7315269..7c7b92424c1 100755 --- a/flash_attn/flash_attn_triton_amd/interface_v3.py +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -1,10 +1,19 @@ -import torch import os +import warnings +import torch from typing import Optional, Union, Tuple from .fwd_prefill import attention_forward_prefill_triton_impl from .fwd_decode import attention_forward_decode_triton_impl from .bwd import attention_backward_triton_impl -from .utils import DEBUG, USE_EXP2, BWD_MODE, PHILOX_SEED, PHILOX_OFFSET, is_fp8 +from .utils import ( + DEBUG, + USE_EXP2, + BWD_MODE, + PHILOX_SEED, + PHILOX_OFFSET, + is_fp8, + get_recommended_fp8_dtype, +) def fwd( @@ -185,9 +194,6 @@ def fwd( "cu_seqlens_k_new is not yet supported in the AMD Triton backend" ) - # if seqlens_rotary is not None: - # raise NotImplementedError("seqlens_rotary is not yet supported in the AMD Triton backend") - # establish layout / varlen & max seq lens if cu_seqlens_q is not None: if len(q.shape) != 3: @@ -241,7 +247,7 @@ def fwd( ) if out is None: - out_dtype = torch.bfloat16 if is_fp8(q) else q.dtype + out_dtype = torch.bfloat16 if is_fp8([q, k, v]) else q.dtype if layout == "bshd": out = torch.zeros( q.shape[0], @@ -262,10 +268,30 @@ def fwd( else: out = out.zero_() - if is_fp8(q): - if (q_descale is None) or (k_descale is None) or (v_descale is None): - import warnings + if is_fp8([q, k, v]): + CAST_TO_REC = str(os.getenv("CAST_TO_REC", "0")).lower() in ("1", "true", "yes", "on") + if CAST_TO_REC: + # check recommended dtype + rec = get_recommended_fp8_dtype(q) + if rec != q.dtype: + warnings.warn( + f"Casting q,k,v from {q.dtype} to recommended {rec} for this architecture.", + UserWarning, + ) + q = q.to(rec) + k = k.to(rec) + v = v.to(rec) + if k_new is not None and is_fp8(k_new): + rec_kn = get_recommended_fp8_dtype(k_new) + if rec_kn != k_new.dtype: + k_new = k_new.to(rec_kn) + if v_new is not None and is_fp8(v_new): + rec_vn = get_recommended_fp8_dtype(v_new) + if rec_vn != v_new.dtype: + v_new = v_new.to(rec_vn) + + if (q_descale is None) or (k_descale is None) or (v_descale is None): warnings.warn( "FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning, diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index e929ad87e4f..03e27e33660 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -889,7 +889,7 @@ def compute_alibi_block( # ------------------------------- # FP8 # ------------------------------- -def is_dtype_fp8(dtype): +def is_dtype_fp8(dtype) -> bool: supported = { torch.float8_e4m3fnuz, torch.float8_e4m3fn, @@ -898,25 +898,62 @@ def is_dtype_fp8(dtype): } if dtype not in supported: return False - if not arch_supports_fp8(): - raise RuntimeError("This device does not support FP8 on this architecture") + return True + + - # check for architecture-specific restrictions +_RECOMMENDED_FP8_REPLACEMENTS = { + "gfx942": { + torch.float8_e4m3fn: torch.float8_e4m3fnuz, + torch.float8_e5m2: torch.float8_e5m2fnuz, + }, +} + +def get_recommended_fp8_dtype(x): + dtype = x.dtype if isinstance(x, torch.Tensor) else x + if not is_dtype_fp8(dtype): + return dtype arch = get_arch() - if arch == "gfx942": - if dtype == torch.float8_e4m3fn: - replacement_dtype = torch.float8_e4m3fnuz - elif dtype == torch.float8_e5m2: - replacement_dtype = torch.float8_e5m2fnuz - else: - replacement_dtype = None - if replacement_dtype is not None: - raise TypeError(f"On {arch} use {replacement_dtype} instead of {dtype}") - return True + return _RECOMMENDED_FP8_REPLACEMENTS.get(arch, {}).get(dtype, dtype) + +def is_fp8(x) -> bool: + """Return whether tensor(s) use FP8. + Accepts either a single tensor or a list/tuple of tensors. -def is_fp8(x): - return is_dtype_fp8(x.dtype) + Rules: + * Single tensor: return True if FP8 (after arch validation), else False. + * Multiple tensors: + - If all tensors are FP8 -> return True. + - If none are FP8 -> return False. + - If a mix of FP8 and non-FP8 -> raise ValueError. + + Empty list/tuple returns False. + """ + + def _is_fp8_single(t: torch.Tensor) -> bool: + if is_dtype_fp8(t.dtype): + arch = get_arch() + if arch not in ("gfx942", "gfx950"): + raise RuntimeError( + f"{arch} is not in the list of supported architectures for FP8" + ) + return True + return False + + if isinstance(x, (list, tuple)): + if len(x) == 0: + return False + flags = [_is_fp8_single(t) for t in x] + if all(flags): + return True + if not any(flags): + return False + raise ValueError( + "Mixed FP8 and non-FP8 tensors provided; either all or none must be FP8." + ) + else: + return _is_fp8_single(x) @triton.jit @@ -1772,9 +1809,4 @@ def is_rdna(): "gfx1102", "gfx1200", "gfx1201", - ) - - -@functools.cache -def arch_supports_fp8(): - return is_hip() and get_arch() in ("gfx942") + ) \ No newline at end of file From baa633069b81c4970a75bf60a5d89dcea4e83eba Mon Sep 17 00:00:00 2001 From: Michael Date: Fri, 3 Oct 2025 20:29:46 -0500 Subject: [PATCH 10/33] use torch.float32 for fp8 kernel --- flash_attn/flash_attn_triton_amd/interface_v3.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py index 7c7b92424c1..4045115e55f 100755 --- a/flash_attn/flash_attn_triton_amd/interface_v3.py +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -247,7 +247,8 @@ def fwd( ) if out is None: - out_dtype = torch.bfloat16 if is_fp8([q, k, v]) else q.dtype + # NOTE: Using types that are lower precision than float32 such as bfloat16 for fp8 causes mismatches on a small set of tests. + out_dtype = torch.float32 if is_fp8([q, k, v]) else q.dtype if layout == "bshd": out = torch.zeros( q.shape[0], From f3ed846e45f191a54d7522403e8780c6f5cc2d2a Mon Sep 17 00:00:00 2001 From: Michael Date: Fri, 3 Oct 2025 23:53:00 -0500 Subject: [PATCH 11/33] add both best fp16 and fp8 config --- .../flash_attn_triton_amd/fwd_prefill.py | 62 +++++++++++-------- 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 1ee7c3e5b8e..c566e2c3c80 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -25,7 +25,8 @@ tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) -def get_fwd_configs(autotune: bool, use_fallback: bool = True): +def get_fwd_configs(autotune: bool): + configs = [] keys = [ "IS_CAUSAL", "dropout_p", @@ -38,55 +39,64 @@ def get_fwd_configs(autotune: bool, use_fallback: bool = True): "HK", ] - # default configs - if not autotune: - # TODO: don't use fallback config used for function correctness testing due to some configs leading error on the scale of 1e-1. - if use_fallback: - cfg = triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ) + # fallback config + if False: + configs.append(triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + )) + return configs, keys - # get best config for the architecture + # get best config for the architecture + if not autotune: arch = get_arch() if arch == "gfx950": - cfg = triton.Config( + configs.append(triton.Config( {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, num_warps=4, - ) + )) elif arch == "gfx942": if get_cu_count() < 304: - cfg = triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "waves_per_eu": 1, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, + configs.extend( + [ + # best fp8 config + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + # best f16 config + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=2, + num_warps=4, + ) + ] ) else: - cfg = triton.Config( + configs.append(triton.Config( {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, num_warps=4, - ) + )) else: - cfg = triton.Config( + configs.append(triton.Config( {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, num_warps=4, - ) + )) - return [cfg], keys + return configs, keys # ===================== Autotune Sweep ===================== - BLOCK_M_OPTIONS = [128, 64, 32, 16] - BLOCK_N_OPTIONS = [128, 64, 32, 16] + BLOCK_M_OPTIONS = [128, 64, 32] + BLOCK_N_OPTIONS = [128, 64, 32] NUM_WARPS_OPTIONS = [2, 4, 8] NUM_STAGES_OPTIONS = [1, 2] WAVES_PER_EU_OPTIONS = [4, 2, 1] PRE_LOAD_V_OPTIONS = [False] - - configs = [] for bm in BLOCK_M_OPTIONS: for bn in BLOCK_N_OPTIONS: for waves in WAVES_PER_EU_OPTIONS: From efa901bff60c24667b026fdece42364b99ae2944 Mon Sep 17 00:00:00 2001 From: Michael Date: Mon, 6 Oct 2025 23:35:53 -0500 Subject: [PATCH 12/33] tune fp8 backward --- flash_attn/flash_attn_triton_amd/bwd.py | 309 ++++++++++++------ .../flash_attn_triton_amd/interface_v3.py | 58 ++-- 2 files changed, 244 insertions(+), 123 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index e47adae3321..916f96e8bcc 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -2,6 +2,7 @@ import torch import triton # type: ignore import triton.language as tl # type: ignore +import warnings from typing import Literal, Optional from .utils import ( DEBUG, @@ -11,8 +12,10 @@ compute_fp8_scaling_factors, create_dropout_mask, create_dropout_mask_varlen, + get_cu_count, is_cdna, is_fp8, + get_arch, ) # NOTE: triton fails to import tl.constexprs so create them here for the file @@ -21,46 +24,65 @@ def get_bwd_configs(autotune: bool): + # keys + preprocess_autotune_keys = [ + "max_seqlen_q", + "ACTUAL_HEAD_DIM", "IS_VARLEN", + ] + + causal_autotune_keys = [ + "dropout_p", "max_seqlen_q", "max_seqlen_k", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + + noncausal_autotune_keys = [ + "dropout_p", "max_seqlen_q", "max_seqlen_k", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + # default config if not autotune: - # preprocess params - PRE_BLOCK = 64 - PRE_WAVES_PER_EU=2 - PRE_NUM_STAGES=2 - PRE_NUM_WARPS=8 - + arch = get_arch() # configs for the kernels - preprocess_autotune_configs = [ - triton.Config({"PRE_BLOCK": PRE_BLOCK, "waves_per_eu": PRE_WAVES_PER_EU}, num_stages=PRE_NUM_STAGES, num_warps=PRE_NUM_WARPS), - ] - preprocess_autotune_keys = [ - "max_seqlen_q", - "ACTUAL_HEAD_DIM", "IS_VARLEN", - ] - - # main params - NUM_STAGES=1 - NUM_WARPS= 4 - WAVES_PER_EU = 1 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 64 - BLK_SLICE_FACTOR = 2 - MATRIX_INSTR_NONKDIM=16 - assert BLOCK_N1 == BLOCK_M2 - - causal_autotune_configs = [ - triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - ] - causal_autotune_keys = [ - "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", - ] - noncausal_autotune_configs = [ - triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), - ] - noncausal_autotune_keys = [ - "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", - ] + if arch == "gfx942": + if get_cu_count() < 304: + preprocess_autotune_configs = [ + triton.Config({"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8), + triton.Config({"PRE_BLOCK": 128, "waves_per_eu": 2}, num_stages=1, num_warps=4), + ] + noncausal_autotune_configs = [ + triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + triton.Config({"BLOCK_M1": 64, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + ] + causal_autotune_configs = [ + triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + ] + else: + preprocess_autotune_configs = [ + triton.Config({"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8), + ] + noncausal_autotune_configs = [ + triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + ] + causal_autotune_configs = [ + triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + ] + else: + preprocess_autotune_configs = [ + triton.Config({"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8), + ] + noncausal_autotune_configs = [ + triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + ] + causal_autotune_configs = [ + triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + ] + + # assert constraints + for (noncausal_cfg, causal_cfg) in zip(noncausal_autotune_configs, causal_autotune_configs): + assert noncausal_cfg.all_kwargs()["BLOCK_N1"] == noncausal_cfg.all_kwargs()["BLOCK_M2"], f"BLOCK_N1 ({noncausal_cfg.all_kwargs()['BLOCK_N1']}) must equal BLOCK_M2 ({noncausal_cfg.all_kwargs()['BLOCK_M2']})" + assert causal_cfg.all_kwargs()["BLOCK_N1"] == causal_cfg.all_kwargs()["BLOCK_M2"], f"BLOCK_N1 ({causal_cfg.all_kwargs()['BLOCK_N1']}) must equal BLOCK_M2 ({causal_cfg.all_kwargs()['BLOCK_M2']})" + return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) @@ -69,21 +91,6 @@ def get_bwd_configs(autotune: bool): PRE_WAVES_PER_EU_OPTIONS=[1, 2] PRE_NUM_STAGES_OPTIONS=[1, 2] PRE_NUM_WARPS_OPTIONS=[4, 8] - - - # Preprocess configs - preprocess_autotune_configs = [] - for pre_num_warps in PRE_NUM_WARPS_OPTIONS: - for pre_num_stages in PRE_NUM_STAGES_OPTIONS: - for pre_waves in PRE_WAVES_PER_EU_OPTIONS: - for pre_block in PRE_BLOCK_OPTIONS: - preprocess_autotune_configs.append( - triton.Config({ - "PRE_BLOCK": pre_block, - "waves_per_eu": pre_waves, - }, num_stages=pre_num_stages, num_warps=pre_num_warps) - ) - NUM_STAGES_OPTIONS = [1, 2] # og: 1 NUM_WARPS_OPTIONS = [4, 8] # og: 4 WAVES_PER_EU_OPTIONS = [1, 2] # og: 1 @@ -98,10 +105,22 @@ def get_bwd_configs(autotune: bool): 32, 64 ] BLK_SLICE_FACTOR_OPTIONS = [2] # og: 2 - - # build configs + + # ==================== sweep configs ================================ + preprocess_autotune_configs = [] causal_autotune_configs = [] - noncausal_autotune_configs = [] + noncausal_autotune_configs = [] + for pre_num_warps in PRE_NUM_WARPS_OPTIONS: + for pre_num_stages in PRE_NUM_STAGES_OPTIONS: + for pre_waves in PRE_WAVES_PER_EU_OPTIONS: + for pre_block in PRE_BLOCK_OPTIONS: + preprocess_autotune_configs.append( + triton.Config({ + "PRE_BLOCK": pre_block, + "waves_per_eu": pre_waves, + }, num_stages=pre_num_stages, num_warps=pre_num_warps) + ) + for num_warps in NUM_WARPS_OPTIONS: for num_stages in NUM_STAGES_OPTIONS: for waves in WAVES_PER_EU_OPTIONS: @@ -135,21 +154,6 @@ def get_bwd_configs(autotune: bool): }, num_stages=num_stages, num_warps=num_warps) ) - # kernel keys - preprocess_autotune_keys = [ - "max_seqlen_q", - "ACTUAL_HEAD_DIM", "IS_VARLEN", - ] - - causal_autotune_keys = [ - "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", - ] - - noncausal_autotune_keys = [ - "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", - ] return (preprocess_autotune_configs, preprocess_autotune_keys), \ (causal_autotune_configs, causal_autotune_keys), \ @@ -3786,8 +3790,8 @@ def attention_backward_triton_split_fused_no_atomics_impl( q.device == k.device == v.device == o.device == do.device == softmax_lse.device ), f"All tensors must be on the same device. Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}, do={do.device}, softmax_lse={softmax_lse.device}" assert ( - q.dtype == k.dtype == v.dtype == do.dtype - ), "q, k, v, do must have the same dtype" + q.dtype == k.dtype == v.dtype + ), "q, k, v must have the same dtype" current_device = torch.cuda.current_device() assert ( q.is_cuda and q.device.index == current_device @@ -3985,21 +3989,68 @@ def attention_backward_triton_split_fused_no_atomics_impl( IS_FP8 = is_fp8([q, k, v]) if IS_FP8: FP8_MAX = torch.finfo(q.dtype).max + + # Check and create default descale tensors if not provided (for inputs) + if (descale_q is None) or (descale_k is None) or (descale_v is None) or (descale_do is None): + warnings.warn( + "FP8 tensors detected but descale factors not provided. Using default scale of 1.0. " + "Note: Backward pass does not support proper FP8 descaling yet.", + UserWarning, + ) + # Create default descale tensors if not provided + if descale_q is None: + descale_q = torch.ones( + batch, nheads_q, dtype=torch.float32, device=q.device + ) + if descale_k is None: + descale_k = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) + if descale_v is None: + descale_v = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) + if descale_do is None: + descale_do = torch.ones( + batch, nheads_q, dtype=torch.float32, device=q.device + ) + # we already asserted that do, q, k, v all have the same dtype, so no need to check each one if is_fp8(o): FP8_OUTPUT = True - assert ( - descale_o is not None - ), f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o." - assert ( - descale_dq is not None - ), f"descale_dq is None. In fp8, you need to pass a tensor for descale_dq along with a tensor dq." - assert ( - descale_dk is not None - ), f"descale_dk is None. In fp8, you need to pass a tensor for descale_dk along with a tensor dk." - assert ( - descale_dv is not None - ), f"descale_dv is None. In fp8, you need to pass a tensor for descale_dv along with a tensor dv." + # Create default descale tensors for outputs if not provided + if descale_o is None: + warnings.warn( + "FP8 output tensor 'o' detected but descale_o not provided. Using default scale of 1.0", + UserWarning, + ) + descale_o = torch.ones( + batch, nheads_q, dtype=torch.float32, device=q.device + ) + if descale_dq is None: + warnings.warn( + "FP8 backward requires descale_dq but not provided. Using default scale of 1.0", + UserWarning, + ) + descale_dq = torch.ones( + batch, nheads_q, dtype=torch.float32, device=q.device + ) + if descale_dk is None: + warnings.warn( + "FP8 backward requires descale_dk but not provided. Using default scale of 1.0", + UserWarning, + ) + descale_dk = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) + if descale_dv is None: + warnings.warn( + "FP8 backward requires descale_dv but not provided. Using default scale of 1.0", + UserWarning, + ) + descale_dv = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) else: FP8_OUTPUT = False @@ -4010,7 +4061,7 @@ def attention_backward_triton_split_fused_no_atomics_impl( stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None if DEBUG: - print(f"FP8 path triggered (FP8_OUTPUT={FP8_OUTPUT})") + print(f"FP8 path triggered in bwd.py (FP8_OUTPUT={FP8_OUTPUT})") else: FP8_MAX = None FP8_OUTPUT = False @@ -4334,6 +4385,41 @@ def attention_backward_triton_fused_atomics_impl( IS_FP8 = is_fp8([q, k, v]) if IS_FP8: FP8_MAX = torch.finfo(q.dtype).max + + # Check and create default descale tensors if not provided + if (descale_q is None) or (descale_k is None) or (descale_v is None) or (descale_do is None): + warnings.warn( + "FP8 tensors detected but descale factors not provided. Using default scale of 1.0. " + "Note: Backward pass does not support proper FP8 descaling yet.", + UserWarning, + ) + # Determine batch size for creating default descale tensors + if cu_seqlens_q is not None: + batch = len(cu_seqlens_q) - 1 + else: + batch = q.shape[0] + + nheads_q = q.shape[1] if cu_seqlens_q is not None else q.shape[2] + nheads_k = k.shape[1] if cu_seqlens_q is not None else k.shape[2] + + # Create default descale tensors if not provided + if descale_q is None: + descale_q = torch.ones( + batch, nheads_q, dtype=torch.float32, device=q.device + ) + if descale_k is None: + descale_k = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) + if descale_v is None: + descale_v = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) + if descale_do is None: + descale_do = torch.ones( + batch, nheads_q, dtype=torch.float32, device=q.device + ) + descale_strides = ( descale_q.stride(0), descale_k.stride(0), @@ -4342,7 +4428,7 @@ def attention_backward_triton_fused_atomics_impl( ) if DEBUG: - print(f"FP8 path triggered") + print(f"FP8 path triggered in bwd.py (fused_atomics)") else: FP8_MAX = None stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = ( @@ -4781,16 +4867,38 @@ def attention_backward_triton_impl( call ONLY this function going forward. mode: 'fused_atomics' or 'fused_no_atomics'; layout: 'bshd' or 'thd'; use_exp2 retained for parity. """ - # Enforce supported dtypes (mirror Hopper behavior: FP8 forward-only) - supported_dtypes = {torch.float16, torch.bfloat16, torch.float32} - for name, t in {"q": q, "k": k, "v": v, "o": o, "do": do}.items(): - if t.dtype not in supported_dtypes: - raise TypeError( - f"Backward only supports fp16/bf16/fp32; tensor '{name}' has dtype {t.dtype}" - ) + # Allow FP8 dtypes and handle gradient tensor dtype casting + dq_original, dk_original, dv_original = None, None, None + do_original = None + + if is_fp8([q, k, v]): + warnings.warn( + "FP8 tensors detected in backward pass. Backward pass supports FP8 inputs but " + "descaling factors will default to 1.0 if not provided.", + UserWarning, + ) + + # For FP8 backward, we need dout to be FP8 for the dot products in the kernel + # The kernel does: tl.dot(v, tl.trans(do)) which requires matching FP8 dtypes + if do.dtype != q.dtype: + do_original = do + # Cast dout to the same FP8 dtype as q/k/v + do = do.to(q.dtype) + + # For the output gradients (dq, dk, dv), we compute in float32 for precision + # and convert back at the end + if dq.dtype != torch.float32: + dq_original = dq + dq = torch.empty(dq.shape, dtype=torch.float32, device=dq.device) + if dk.dtype != torch.float32: + dk_original = dk + dk = torch.empty(dk.shape, dtype=torch.float32, device=dk.device) + if dv.dtype != torch.float32: + dv_original = dv + dv = torch.empty(dv.shape, dtype=torch.float32, device=dv.device) if mode == "fused_atomics": - return attention_backward_triton_fused_atomics_impl( + delta = attention_backward_triton_fused_atomics_impl( do, q, k, @@ -4819,7 +4927,7 @@ def attention_backward_triton_impl( None, ) elif mode == "fused_no_atomics": - return attention_backward_triton_split_fused_no_atomics_impl( + delta = attention_backward_triton_split_fused_no_atomics_impl( do, q, k, @@ -4856,3 +4964,14 @@ def attention_backward_triton_impl( raise ValueError( f"Unknown backward mode '{mode}'. Expected 'fused_atomics' or 'fused_no_atomics'." ) + + # Copy float32 gradients back to original FP8 tensors if needed + # Note: This conversion happens only once at the end, not in a loop + if dq_original is not None: + dq_original.copy_(dq.to(dq_original.dtype)) + if dk_original is not None: + dk_original.copy_(dk.to(dk_original.dtype)) + if dv_original is not None: + dv_original.copy_(dv.to(dv_original.dtype)) + + return delta diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py index 4045115e55f..7490c2f4aad 100755 --- a/flash_attn/flash_attn_triton_amd/interface_v3.py +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -61,13 +61,13 @@ def fwd( if DEBUG: print() print("interface_fa_v3.py::fwd inputs") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("k_new:", k_new, k_new.shape if k_new is not None else None) - print("v_new:", v_new, v_new.shape if v_new is not None else None) - print("qv:", qv, qv.shape if qv is not None else None) - print("out:", out, out.shape if out is not None else None) + print("q:", q.dtype if q is not None else None, q.shape) + print("k:", k.dtype if k is not None else None, k.shape) + print("v:", v.dtype if v is not None else None, v.shape) + print("k_new:", k_new.dtype if k_new is not None else None, k_new.shape if k_new is not None else None) + print("v_new:", v_new.dtype if v_new is not None else None, v_new.shape if v_new is not None else None) + print("qv:", qv.dtype if qv is not None else None, qv.shape if qv is not None else None) + print("out:", out.dtype if out is not None else None, out.shape if out is not None else None) print( "cu_seqlens_q:", cu_seqlens_q, @@ -120,13 +120,13 @@ def fwd( seqlens_rotary.shape if seqlens_rotary is not None else None, ) print( - "q_descale:", q_descale, q_descale.shape if q_descale is not None else None + "q_descale:", q_descale.dtype if q_descale is not None else None, q_descale.shape if q_descale is not None else None ) print( - "k_descale:", k_descale, k_descale.shape if k_descale is not None else None + "k_descale:", k_descale.dtype if k_descale is not None else None, k_descale.shape if k_descale is not None else None ) print( - "v_descale:", v_descale, v_descale.shape if v_descale is not None else None + "v_descale:", v_descale.dtype if v_descale is not None else None, v_descale.shape if v_descale is not None else None ) print("softmax_scale:", softmax_scale) print("causal:", causal) @@ -411,8 +411,8 @@ def fwd( if DEBUG: print("interface_fa_v3.py::fwd outputs") - print("out:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("out:", out.dtype if out is not None else None, out.shape if out is not None else None) + print("softmax_lse:", softmax_lse.dtype if softmax_lse is not None else None, softmax_lse.shape if softmax_lse is not None else None) # Return format compatible with v3 # V3 returns (out, softmax_lse, *rest) where rest can be empty or contain additional outputs @@ -452,15 +452,15 @@ def bwd( if DEBUG: print() print("interface_fa_v3.py::bwd inputs") - print("dout:", dout, dout.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("out:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape if dq is not None else None) - print("dk:", dk, dk.shape if dk is not None else None) - print("dv:", dv, dv.shape if dv is not None else None) + print("dout:", dout.dtype if dout is not None else None, dout.shape if dout is not None else None) + print("q:", q.dtype if q is not None else None, q.shape if q is not None else None) + print("k:", k.dtype if k is not None else None, k.shape if k is not None else None) + print("v:", v.dtype if v is not None else None, v.shape if v is not None else None) + print("out:", out.dtype if out is not None else None, out.shape if out is not None else None) + print("softmax_lse:", softmax_lse.dtype if softmax_lse is not None else None, softmax_lse.shape if softmax_lse is not None else None) + print("dq:", dq.dtype if dq is not None else None, dq.shape if dq is not None else None) + print("dk:", dk.dtype if dk is not None else None, dk.shape if dk is not None else None) + print("dv:", dv.dtype if dv is not None else None, dv.shape if dv is not None else None) print( "cu_seqlens_q:", cu_seqlens_q, @@ -502,9 +502,11 @@ def bwd( ) # Initialize gradient tensors if not provided - dq = torch.zeros_like(q) if dq is None else dq.zero_() - dk = torch.zeros_like(k) if dk is None else dk.zero_() - dv = torch.zeros_like(v) if dv is None else dv.zero_() + # NOTE: Using types that are lower precision than float32 such as bfloat16 for fp8 causes mismatches on a small set of tests. + grad_dtype = torch.float32 if is_fp8([q, k, v]) else q.dtype + dq = torch.zeros_like(q, dtype=grad_dtype) if dq is None else dq.zero_() + dk = torch.zeros_like(k, dtype=grad_dtype) if dk is None else dk.zero_() + dv = torch.zeros_like(v, dtype=grad_dtype) if dv is None else dv.zero_() # Determine layout based on cu_seqlens if cu_seqlens_q is not None and cu_seqlens_k is not None: @@ -556,10 +558,10 @@ def bwd( if DEBUG: print("interface_fa_v3.py::bwd outputs") - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) - print("dv:", dv, dv.shape) - print("delta:", delta, delta.shape if delta is not None else None) + print("dq:", dq.dtype if dq is not None else None, dq.shape if dq is not None else None) + print("dk:", dk.dtype if dk is not None else None, dk.shape if dk is not None else None) + print("dv:", dv.dtype if dv is not None else None, dv.shape if dv is not None else None) + print("delta:", delta.dtype if delta is not None else None, delta.shape if delta is not None else None) # V3 expects (dq, dk, dv, softmax_d, *rest) # delta is the softmax_d in this case From 0038a5c5f64489cf890597daf2b1ca33bfccb705 Mon Sep 17 00:00:00 2001 From: Michael Date: Tue, 7 Oct 2025 15:45:52 -0500 Subject: [PATCH 13/33] descale factors should be b, hk --- flash_attn/flash_attn_triton_amd/bwd.py | 82 ++++++++----------------- 1 file changed, 26 insertions(+), 56 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 916f96e8bcc..52a7e7e4526 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -1847,7 +1847,9 @@ def _bwd_kernel_fused_atomics_dkdvdq_noncausal( ) if IS_FP8: - descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hkid) descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) @@ -2048,7 +2050,9 @@ def _bwd_kernel_fused_atomics_dkdv_noncausal( ) if IS_FP8: - descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hkid) descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) @@ -2234,7 +2238,9 @@ def _bwd_kernel_fused_atomics_dq_noncausal( # FP8 if IS_FP8: - descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hkid) descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) @@ -2847,7 +2853,6 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, - FP8_OUTPUT: tl.constexpr, USE_SEQUSED: tl.constexpr, # Add flag for seqused DEBUG_TRITON: tl.constexpr, DEBUG_TRITON_DETAIL: tl.constexpr, @@ -2999,7 +3004,9 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b ) if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) @@ -3217,7 +3224,9 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N2) if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) @@ -3431,7 +3440,6 @@ def bwd_kernel_noncausal( USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, - FP8_OUTPUT: tl.constexpr, USE_SEQUSED: tl.constexpr, # Add flag for seqused DEBUG_TRITON: tl.constexpr, DEBUG_TRITON_DETAIL: tl.constexpr, @@ -3541,7 +3549,9 @@ def bwd_kernel_noncausal( ) if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) @@ -3662,7 +3672,9 @@ def bwd_kernel_noncausal( m = m[:, None] if IS_FP8: - descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) @@ -3998,9 +4010,10 @@ def attention_backward_triton_split_fused_no_atomics_impl( UserWarning, ) # Create default descale tensors if not provided + # For GQA/MQA, q_descale should be shaped (batch, nheads_k) to match forward pass if descale_q is None: descale_q = torch.ones( - batch, nheads_q, dtype=torch.float32, device=q.device + batch, nheads_k, dtype=torch.float32, device=q.device ) if descale_k is None: descale_k = torch.ones( @@ -4015,59 +4028,18 @@ def attention_backward_triton_split_fused_no_atomics_impl( batch, nheads_q, dtype=torch.float32, device=q.device ) - # we already asserted that do, q, k, v all have the same dtype, so no need to check each one - if is_fp8(o): - FP8_OUTPUT = True - # Create default descale tensors for outputs if not provided - if descale_o is None: - warnings.warn( - "FP8 output tensor 'o' detected but descale_o not provided. Using default scale of 1.0", - UserWarning, - ) - descale_o = torch.ones( - batch, nheads_q, dtype=torch.float32, device=q.device - ) - if descale_dq is None: - warnings.warn( - "FP8 backward requires descale_dq but not provided. Using default scale of 1.0", - UserWarning, - ) - descale_dq = torch.ones( - batch, nheads_q, dtype=torch.float32, device=q.device - ) - if descale_dk is None: - warnings.warn( - "FP8 backward requires descale_dk but not provided. Using default scale of 1.0", - UserWarning, - ) - descale_dk = torch.ones( - batch, nheads_k, dtype=torch.float32, device=q.device - ) - if descale_dv is None: - warnings.warn( - "FP8 backward requires descale_dv but not provided. Using default scale of 1.0", - UserWarning, - ) - descale_dv = torch.ones( - batch, nheads_k, dtype=torch.float32, device=q.device - ) - else: - FP8_OUTPUT = False - stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None - stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None if DEBUG: - print(f"FP8 path triggered in bwd.py (FP8_OUTPUT={FP8_OUTPUT})") + print(f"FP8 path triggered in bwd.py") else: FP8_MAX = None - FP8_OUTPUT = False stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = ( - stride_descale_o_z - ) = stride_descale_do_z = None + stride_descale_do_z + ) = None # alibi setup use_alibi, (stride_az, stride_ah) = ( @@ -4252,7 +4224,6 @@ def attention_backward_triton_split_fused_no_atomics_impl( USE_EXP2=use_exp2, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, - FP8_OUTPUT=FP8_OUTPUT, USE_SEQUSED=( seqused_q is not None or seqused_k is not None ), # Add flag for seqused @@ -4342,7 +4313,6 @@ def attention_backward_triton_split_fused_no_atomics_impl( USE_EXP2=use_exp2, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, - FP8_OUTPUT=FP8_OUTPUT, USE_SEQUSED=( seqused_q is not None or seqused_k is not None ), # Add flag for seqused From 449471f576b7eee838921e3b460d9a341ab7267b Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 8 Oct 2025 12:33:49 -0500 Subject: [PATCH 14/33] fp8 bwd working on all primus configs --- flash_attn/flash_attn_triton_amd/bwd.py | 292 +++++------------------- 1 file changed, 52 insertions(+), 240 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 52a7e7e4526..1f33b8483a1 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -184,10 +184,8 @@ def _bwd_fused_atomics_preprocess( stride_delta_b, stride_delta_h, stride_delta_m, - stride_descale_do_z, cu_seqlens_q, max_seqlen_q, - descale_do_ptr, BLOCK_M: tl.constexpr, BLOCK_D_MODEL: tl.constexpr, BLOCK_D_MODEL_POW2: tl.constexpr, @@ -234,13 +232,8 @@ def _bwd_fused_atomics_preprocess( do = tl.load(do_ptr + offs, mask=mask, other=0.0) # compute and write-back to delta - if IS_FP8: - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hid) - - # NOTE: do is in the fp8 range and o is not in fp8 - delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) - else: - delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) + # NOTE: Both o and do are FP32 + delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) offs_delta = ( bid * stride_delta_b @@ -283,7 +276,6 @@ def _bwd_fused_atomics_dq_inner( descale_q, descale_k, descale_v, - descale_do, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D_MODEL: tl.constexpr, @@ -352,7 +344,7 @@ def _bwd_fused_atomics_dq_inner( # dp if IS_FP8: - dp = tl.dot(do, vT) * descale_do * descale_v + dp = tl.dot(do.to(vT.type.element_ty), vT) * descale_v else: dp = tl.dot(do, vT) @@ -366,12 +358,7 @@ def _bwd_fused_atomics_dq_inner( # dq # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. if IS_FP8: - scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) - dq += ( - tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) - * descale_ds - * descale_k - ) + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) * descale_k else: dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) @@ -411,7 +398,6 @@ def _bwd_fused_atomics_dkdv_inner( descale_q, descale_k, descale_v, - descale_do, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D_MODEL: tl.constexpr, @@ -502,34 +488,16 @@ def _bwd_fused_atomics_dkdv_inner( # dV if ENABLE_DROPOUT: pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors( - pT_dropout, FP8_MAX - ) - dv += ( - tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) - * descale_p_dropout - * descale_do - ) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) else: - if IS_FP8: - scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) - dv += ( - tl.dot((pT * scale_pT).to(do.type.element_ty), do) - * descale_pT - * descale_do - ) - else: - dv += tl.dot(pT.to(do.type.element_ty), do) + dv += tl.dot(pT.to(do.type.element_ty), do) # Load delta Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) # Compute dP and dS if IS_FP8: - dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do + dpT = tl.dot(v, tl.trans(do.to(v.type.element_ty))) * descale_v else: dpT = tl.dot(v, tl.trans(do)) @@ -541,14 +509,13 @@ def _bwd_fused_atomics_dkdv_inner( # compute dk if IS_FP8: - scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += ( - tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) - * descale_dsT - * descale_q - ) + # Rewrite dk += dsT @ qT.T as dk += (qT @ dsT.T).T + # This puts FP8 tensor (qT) on LHS of dot product + # Cast the transposed dsT to FP8 to match qT's dtype + dsT_transposed = tl.trans(dsT).to(qT.type.element_ty) + dk += tl.trans(tl.dot(qT, dsT_transposed)) * descale_q else: - dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + dk += tl.dot(dsT, tl.trans(qT)) # increment pointers curr_m += step_m @@ -591,7 +558,6 @@ def _bwd_fused_atomics_dkdvdq_inner( descale_q, descale_k, descale_v, - descale_do, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D_MODEL: tl.constexpr, @@ -706,34 +672,16 @@ def _bwd_fused_atomics_dkdvdq_inner( # dV if ENABLE_DROPOUT: pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors( - pT_dropout, FP8_MAX - ) - dv += ( - tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) - * descale_p_dropout - * descale_do - ) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) else: - if IS_FP8: - scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) - dv += ( - tl.dot((pT * scale_pT).to(do.type.element_ty), do) - * descale_pT - * descale_do - ) - else: - dv += tl.dot(pT.to(do.type.element_ty), do) + dv += tl.dot(pT.to(do.type.element_ty), do) # Load delta Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) # Compute dP and dS if IS_FP8: - dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do + dpT = tl.dot(v, tl.trans(do.to(v.type.element_ty))) * descale_v else: dpT = tl.dot(v, tl.trans(do)) @@ -745,24 +693,23 @@ def _bwd_fused_atomics_dkdvdq_inner( # compute dk if IS_FP8: - scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += ( - tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) - * descale_dsT - * descale_q - ) + # Rewrite dk += dsT @ qT.T as dk += (qT @ dsT.T).T + # This puts FP8 tensor (qT) on LHS of dot product + # Cast the transposed dsT to FP8 to match qT's dtype + dsT_transposed = tl.trans(dsT).to(qT.type.element_ty) + dk += tl.trans(tl.dot(qT, dsT_transposed)) * descale_q else: - dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + dk += tl.dot(dsT, tl.trans(qT)) # We can compute the dq_partial here and do a atomic add to the correct memory location # NOTE: Possible problems with the atomic add: contention, is inside a loop which has achieved bad perf before # (BLOCK_M, BLOCK_N) x (BLOCK_N, D) if IS_FP8: dq_partial = ( - tl.dot((dsT * scale_dsT).to(k.dtype).T, k) * descale_dsT * descale_k + tl.dot(dsT.to(k.type.element_ty).T, k) * descale_k ) else: - dq_partial = tl.dot(dsT.to(k.dtype).T, k) + dq_partial = tl.dot(dsT.to(k.type.element_ty).T, k) tl.atomic_add( dq_ptrs, dq_partial * sm_scale, @@ -819,7 +766,6 @@ def _bwd_kernel_fused_atomics_dkdvdq_causal( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -831,7 +777,6 @@ def _bwd_kernel_fused_atomics_dkdvdq_causal( descale_q_ptr, descale_k_ptr, descale_v_ptr, - descale_do_ptr, NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr, BATCH, @@ -994,11 +939,8 @@ def _bwd_kernel_fused_atomics_dkdvdq_causal( descale_v = tl.load( descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx ) - descale_do = tl.load( - descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx - ) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 # if unaligned start_m is negative, the current N-tile has no block on the # diagonal of causal mask, so everything have no causal mask @@ -1034,7 +976,6 @@ def _bwd_kernel_fused_atomics_dkdvdq_causal( descale_q, descale_k, descale_v, - descale_do, # fp8 descale factors from user MASK_BLOCK_M, BLOCK_N, # block dim BLOCK_D_MODEL, @@ -1082,7 +1023,6 @@ def _bwd_kernel_fused_atomics_dkdvdq_causal( descale_q, descale_k, descale_v, - descale_do, # fp8 descale factors from user BLOCK_M, BLOCK_N, # block dim BLOCK_D_MODEL, @@ -1148,7 +1088,6 @@ def _bwd_kernel_fused_atomics_dkdv_causal( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1160,7 +1099,6 @@ def _bwd_kernel_fused_atomics_dkdv_causal( descale_q_ptr, descale_k_ptr, descale_v_ptr, - descale_do_ptr, NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr, BLOCK_M: tl.constexpr, @@ -1312,11 +1250,8 @@ def _bwd_kernel_fused_atomics_dkdv_causal( descale_v = tl.load( descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx ) - descale_do = tl.load( - descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx - ) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 # if start_m is negative, the current N-tile has no block on the # diagonal of causal mask, so everything have no causal mask @@ -1349,7 +1284,6 @@ def _bwd_kernel_fused_atomics_dkdv_causal( descale_q, descale_k, descale_v, - descale_do, # fp8 descale factors from user MASK_BLOCK_M, BLOCK_N, # block dim BLOCK_D_MODEL, @@ -1392,7 +1326,6 @@ def _bwd_kernel_fused_atomics_dkdv_causal( descale_q, descale_k, descale_v, - descale_do, # fp8 descale factors from user BLOCK_M, BLOCK_N, # block dim BLOCK_D_MODEL, @@ -1456,7 +1389,6 @@ def _bwd_kernel_fused_atomics_dq_causal( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1468,7 +1400,6 @@ def _bwd_kernel_fused_atomics_dq_causal( descale_q_ptr, descale_k_ptr, descale_v_ptr, - descale_do_ptr, NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr, BLOCK_M: tl.constexpr, @@ -1587,11 +1518,8 @@ def _bwd_kernel_fused_atomics_dq_causal( descale_v = tl.load( descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx ) - descale_do = tl.load( - descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx - ) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32) # Compute dQ for masked (diagonal) blocks. @@ -1630,7 +1558,6 @@ def _bwd_kernel_fused_atomics_dq_causal( descale_q, descale_k, descale_v, - descale_do, BLOCK_M, MASK_BLOCK_N, BLOCK_D_MODEL, @@ -1674,7 +1601,6 @@ def _bwd_kernel_fused_atomics_dq_causal( descale_q, descale_k, descale_v, - descale_do, BLOCK_M, BLOCK_N, BLOCK_D_MODEL, @@ -1742,7 +1668,6 @@ def _bwd_kernel_fused_atomics_dkdvdq_noncausal( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1754,7 +1679,6 @@ def _bwd_kernel_fused_atomics_dkdvdq_noncausal( descale_q_ptr, descale_k_ptr, descale_v_ptr, - descale_do_ptr, NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr, BATCH, @@ -1852,9 +1776,8 @@ def _bwd_kernel_fused_atomics_dkdvdq_noncausal( descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hkid) descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 start_m = 0 num_steps = tl.cdiv(seqlen_q, BLOCK_M) @@ -1891,7 +1814,6 @@ def _bwd_kernel_fused_atomics_dkdvdq_noncausal( descale_q, descale_k, descale_v, - descale_do, BLOCK_M, BLOCK_N, BLOCK_D_MODEL, @@ -1956,7 +1878,6 @@ def _bwd_kernel_fused_atomics_dkdv_noncausal( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -1968,7 +1889,6 @@ def _bwd_kernel_fused_atomics_dkdv_noncausal( descale_q_ptr, descale_k_ptr, descale_v_ptr, - descale_do_ptr, NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr, BLOCK_M: tl.constexpr, @@ -2055,9 +1975,8 @@ def _bwd_kernel_fused_atomics_dkdv_noncausal( descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hkid) descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 start_m = 0 num_steps = tl.cdiv(seqlen_q, BLOCK_M) @@ -2090,7 +2009,6 @@ def _bwd_kernel_fused_atomics_dkdv_noncausal( descale_q, descale_k, descale_v, - descale_do, BLOCK_M, BLOCK_N, BLOCK_D_MODEL, @@ -2153,7 +2071,6 @@ def _bwd_kernel_fused_atomics_dq_noncausal( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, @@ -2165,7 +2082,6 @@ def _bwd_kernel_fused_atomics_dq_noncausal( descale_q_ptr, descale_k_ptr, descale_v_ptr, - descale_do_ptr, NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr, BLOCK_M: tl.constexpr, @@ -2243,9 +2159,8 @@ def _bwd_kernel_fused_atomics_dq_noncausal( descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hkid) descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) - descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 start_n = 0 end_n = seqlen_k @@ -2282,7 +2197,6 @@ def _bwd_kernel_fused_atomics_dq_noncausal( descale_q, descale_k, descale_v, - descale_do, BLOCK_M, BLOCK_N, BLOCK_D_MODEL, @@ -2325,10 +2239,8 @@ def _bwd_preprocess( stride_delta_b, stride_delta_h, stride_delta_m, - stride_descale_do_z, cu_seqlens_q, max_seqlen_q, - Descale_do, PRE_BLOCK: tl.constexpr, HEAD_DIM_V: tl.constexpr, ACTUAL_HEAD_DIM_V: tl.constexpr, @@ -2376,14 +2288,8 @@ def _bwd_preprocess( o = tl.load(O + off_o, mask=mask_md, other=0.0) do = tl.load(DO + off_do, mask=mask_md, other=0.0) # compute and write-back to delta - if IS_FP8: - off_descale_do = bid * stride_descale_do_z + hid - descale_do = tl.load(Descale_do + off_descale_do) - - # NOTE: do is in the fp8 range and o is not in fp8 - delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) - else: - delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) + # NOTE: Both o and do are FP32 + delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) off_delta = ( bid * stride_delta_b + hid * stride_delta_h @@ -2433,7 +2339,6 @@ def _bwd_dkdv_inner( descale_q, descale_k, descale_v, - descale_do, # fp8 descale factors from user MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal ENABLE_DROPOUT: tl.constexpr, # activate dropout USE_ALIBI: tl.constexpr, @@ -2540,27 +2445,9 @@ def _bwd_dkdv_inner( # Compute dV. if ENABLE_DROPOUT: pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale - if IS_FP8: - scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors( - pT_dropout, FP8_MAX - ) - dv += ( - tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) - * descale_p_dropout - * descale_do - ) - else: - dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) else: - if IS_FP8: - scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) - dv += ( - tl.dot((pT * scale_pT).to(do.type.element_ty), do) - * descale_pT - * descale_do - ) - else: - dv += tl.dot(pT.to(do.type.element_ty), do) + dv += tl.dot(pT.to(do.type.element_ty), do) if DEBUG_TRITON_DETAIL: if start_n == 256: @@ -2569,7 +2456,7 @@ def _bwd_dkdv_inner( Di = tl.load(D + offs_m * stride_delta_m, mask=mask_m) # Compute dP and dS. if IS_FP8: - dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do + dpT = tl.dot(v, tl.trans(do.to(v.type.element_ty))) * descale_v else: dpT = tl.dot(v, tl.trans(do)) if ENABLE_DROPOUT: @@ -2577,14 +2464,13 @@ def _bwd_dkdv_inner( delta_i = Di[None, :] dsT = pT * (dpT - delta_i) if IS_FP8: - scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) - dk += ( - tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) - * descale_dsT - * descale_q - ) + # Rewrite dk += dsT @ qT.T as dk += (qT @ dsT.T).T + # This puts FP8 tensor (qT) on LHS of dot product + # Cast the transposed dsT to FP8 to match qT's dtype + dsT_transposed = tl.trans(dsT).to(qT.type.element_ty) + dk += tl.trans(tl.dot(qT, dsT_transposed)) * descale_q else: - dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + dk += tl.dot(dsT, tl.trans(qT)) # Increment pointers. curr_m += step_m qT_ptrs += step_m * stride_qm @@ -2635,7 +2521,6 @@ def _bwd_dq_inner( descale_q, descale_k, descale_v, - descale_do, # fp8 descale factors from user MASK: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, USE_ALIBI: tl.constexpr, @@ -2735,7 +2620,7 @@ def _bwd_dq_inner( p = tl.where(mask, p, 0.0) # Compute dP and dS. if IS_FP8: - dp = tl.dot(do, vT) * descale_do * descale_v + dp = tl.dot(do.to(vT.type.element_ty), vT) * descale_v else: dp = tl.dot(do, vT) if ENABLE_DROPOUT: @@ -2745,12 +2630,7 @@ def _bwd_dq_inner( # Compute dQ. # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. if IS_FP8: - scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) - dq += ( - tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) - * descale_ds - * descale_k - ) + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) * descale_k else: dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) # Increment pointers. @@ -2818,7 +2698,6 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, stride_az, stride_ah, HQ, @@ -2837,7 +2716,6 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b Descale_q, Descale_k, Descale_v, - Descale_do, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, BLOCK_M2: tl.constexpr, @@ -3009,9 +2887,8 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR # bound the masked operation to q len so it does not have to wast cycles @@ -3064,7 +2941,6 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b descale_q, descale_k, descale_v, - descale_do, MASK=True, # causal masking ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout USE_ALIBI=USE_ALIBI, @@ -3125,7 +3001,6 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b descale_q, descale_k, descale_v, - descale_do, MASK=False, # causal masking ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout USE_ALIBI=USE_ALIBI, @@ -3229,9 +3104,8 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) dq = _bwd_dq_inner( @@ -3273,7 +3147,6 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b descale_q, descale_k, descale_v, - descale_do, MASK=True, # ENABLE_DROPOUT=ENABLE_DROPOUT, USE_ALIBI=USE_ALIBI, @@ -3329,7 +3202,6 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b descale_q, descale_k, descale_v, - descale_do, MASK=False, ENABLE_DROPOUT=ENABLE_DROPOUT, USE_ALIBI=USE_ALIBI, @@ -3405,7 +3277,6 @@ def bwd_kernel_noncausal( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, stride_az, stride_ah, HQ, @@ -3424,7 +3295,6 @@ def bwd_kernel_noncausal( Descale_q, Descale_k, Descale_v, - Descale_do, BLOCK_M1: tl.constexpr, # 32 BLOCK_N1: tl.constexpr, # 128 BLOCK_M2: tl.constexpr, # 128 @@ -3554,9 +3424,8 @@ def bwd_kernel_noncausal( descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 # because there is no causal, we always start from the beginning start_m = 0 @@ -3598,7 +3467,6 @@ def bwd_kernel_noncausal( descale_q, descale_k, descale_v, - descale_do, # fp8 descale factors from user MASK=False, # causal masking ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout USE_ALIBI=USE_ALIBI, @@ -3677,9 +3545,8 @@ def bwd_kernel_noncausal( descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) - descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) else: - descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 # start can only be 0 at minimum start_n = 0 @@ -3726,7 +3593,6 @@ def bwd_kernel_noncausal( descale_q, descale_k, descale_v, - descale_do, MASK=False, ENABLE_DROPOUT=ENABLE_DROPOUT, USE_ALIBI=USE_ALIBI, @@ -3782,7 +3648,6 @@ def attention_backward_triton_split_fused_no_atomics_impl( descale_k: Optional[torch.Tensor], descale_v: Optional[torch.Tensor], descale_o: Optional[torch.Tensor], - descale_do: Optional[torch.Tensor], descale_dq: Optional[torch.Tensor], descale_dk: Optional[torch.Tensor], descale_dv: Optional[torch.Tensor], @@ -4003,10 +3868,9 @@ def attention_backward_triton_split_fused_no_atomics_impl( FP8_MAX = torch.finfo(q.dtype).max # Check and create default descale tensors if not provided (for inputs) - if (descale_q is None) or (descale_k is None) or (descale_v is None) or (descale_do is None): + if (descale_q is None) or (descale_k is None) or (descale_v is None): warnings.warn( - "FP8 tensors detected but descale factors not provided. Using default scale of 1.0. " - "Note: Backward pass does not support proper FP8 descaling yet.", + "FP8 tensors detected but descale factors not provided. Using default scale of 1.0.", UserWarning, ) # Create default descale tensors if not provided @@ -4023,22 +3887,16 @@ def attention_backward_triton_split_fused_no_atomics_impl( descale_v = torch.ones( batch, nheads_k, dtype=torch.float32, device=q.device ) - if descale_do is None: - descale_do = torch.ones( - batch, nheads_q, dtype=torch.float32, device=q.device - ) stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None - stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None if DEBUG: print(f"FP8 path triggered in bwd.py") else: FP8_MAX = None stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = ( - stride_descale_do_z ) = None # alibi setup @@ -4094,10 +3952,8 @@ def attention_backward_triton_split_fused_no_atomics_impl( stride_delta_b, stride_delta_h, stride_delta_m, - stride_descale_do_z, cu_seqlens_q, max_seqlen_q, - descale_do, HEAD_DIM_V=HEAD_DIM_V, ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, IS_VARLEN=IS_VARLEN, @@ -4194,7 +4050,6 @@ def attention_backward_triton_split_fused_no_atomics_impl( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, stride_az, stride_ah, nheads_q, @@ -4213,7 +4068,6 @@ def attention_backward_triton_split_fused_no_atomics_impl( descale_q, descale_k, descale_v, - descale_do, HEAD_DIM_QK=HEAD_DIM_QK, HEAD_DIM_V=HEAD_DIM_V, ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, @@ -4283,7 +4137,6 @@ def attention_backward_triton_split_fused_no_atomics_impl( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, stride_az, stride_ah, nheads_q, @@ -4302,7 +4155,6 @@ def attention_backward_triton_split_fused_no_atomics_impl( descale_q, descale_k, descale_v, - descale_do, HEAD_DIM_QK=HEAD_DIM_QK, HEAD_DIM_V=HEAD_DIM_V, ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, @@ -4346,7 +4198,6 @@ def attention_backward_triton_fused_atomics_impl( descale_q: Optional[torch.Tensor] = None, descale_k: Optional[torch.Tensor] = None, descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None, fused: bool = False, # seqused for FA v3 (currently ignored in this implementation) seqused_q: Optional[torch.Tensor] = None, @@ -4359,8 +4210,7 @@ def attention_backward_triton_fused_atomics_impl( # Check and create default descale tensors if not provided if (descale_q is None) or (descale_k is None) or (descale_v is None) or (descale_do is None): warnings.warn( - "FP8 tensors detected but descale factors not provided. Using default scale of 1.0. " - "Note: Backward pass does not support proper FP8 descaling yet.", + "FP8 tensors detected but descale factors not provided. Using default scale of 1.0.", UserWarning, ) # Determine batch size for creating default descale tensors @@ -4402,13 +4252,11 @@ def attention_backward_triton_fused_atomics_impl( else: FP8_MAX = None stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = ( - stride_descale_do_z ) = None descale_strides = ( stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, - stride_descale_do_z, ) IS_VARLEN = True if cu_seqlens_q is not None else False @@ -4481,7 +4329,6 @@ def attention_backward_triton_fused_atomics_impl( descale_strides[3], cu_seqlens_q, max_seqlen_q, - descale_do, BLOCK_M=PRE_BLOCK, BLOCK_D_MODEL=head_sz, BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, @@ -4556,7 +4403,6 @@ def attention_backward_triton_fused_atomics_impl( descale_q, descale_k, descale_v, - descale_do, NUM_Q_HEADS=num_q_heads, NUM_K_HEADS=num_k_heads, BATCH=batch, @@ -4601,7 +4447,6 @@ def attention_backward_triton_fused_atomics_impl( descale_q, descale_k, descale_v, - descale_do, NUM_Q_HEADS=num_q_heads, NUM_K_HEADS=num_k_heads, BATCH=batch, @@ -4649,7 +4494,6 @@ def attention_backward_triton_fused_atomics_impl( descale_q, descale_k, descale_v, - descale_do, NUM_Q_HEADS=num_q_heads, NUM_K_HEADS=num_k_heads, BLOCK_M=BLOCK_M1, @@ -4693,7 +4537,6 @@ def attention_backward_triton_fused_atomics_impl( descale_q, descale_k, descale_v, - descale_do, NUM_Q_HEADS=num_q_heads, NUM_K_HEADS=num_k_heads, BLOCK_M=BLOCK_M2, @@ -4739,7 +4582,6 @@ def attention_backward_triton_fused_atomics_impl( descale_q, descale_k, descale_v, - descale_do, NUM_Q_HEADS=num_q_heads, NUM_K_HEADS=num_k_heads, BLOCK_M=BLOCK_M1, @@ -4784,7 +4626,6 @@ def attention_backward_triton_fused_atomics_impl( descale_q, descale_k, descale_v, - descale_do, NUM_Q_HEADS=num_q_heads, NUM_K_HEADS=num_k_heads, BLOCK_M=BLOCK_M2, @@ -4844,28 +4685,9 @@ def attention_backward_triton_impl( if is_fp8([q, k, v]): warnings.warn( "FP8 tensors detected in backward pass. Backward pass supports FP8 inputs but " - "descaling factors will default to 1.0 if not provided.", + "descaling factors will default to 1.0.", UserWarning, ) - - # For FP8 backward, we need dout to be FP8 for the dot products in the kernel - # The kernel does: tl.dot(v, tl.trans(do)) which requires matching FP8 dtypes - if do.dtype != q.dtype: - do_original = do - # Cast dout to the same FP8 dtype as q/k/v - do = do.to(q.dtype) - - # For the output gradients (dq, dk, dv), we compute in float32 for precision - # and convert back at the end - if dq.dtype != torch.float32: - dq_original = dq - dq = torch.empty(dq.shape, dtype=torch.float32, device=dq.device) - if dk.dtype != torch.float32: - dk_original = dk - dk = torch.empty(dk.shape, dtype=torch.float32, device=dk.device) - if dv.dtype != torch.float32: - dv_original = dv - dv = torch.empty(dv.shape, dtype=torch.float32, device=dv.device) if mode == "fused_atomics": delta = attention_backward_triton_fused_atomics_impl( @@ -4926,7 +4748,6 @@ def attention_backward_triton_impl( None, None, None, - None, seqused_q, seqused_k, ) @@ -4934,14 +4755,5 @@ def attention_backward_triton_impl( raise ValueError( f"Unknown backward mode '{mode}'. Expected 'fused_atomics' or 'fused_no_atomics'." ) - - # Copy float32 gradients back to original FP8 tensors if needed - # Note: This conversion happens only once at the end, not in a loop - if dq_original is not None: - dq_original.copy_(dq.to(dq_original.dtype)) - if dk_original is not None: - dk_original.copy_(dk.to(dk_original.dtype)) - if dv_original is not None: - dv_original.copy_(dv.to(dv_original.dtype)) - + return delta From 88b37e98ced6da3474d5934a91bae1d110a9e0b0 Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 8 Oct 2025 17:27:07 -0500 Subject: [PATCH 15/33] tune bwd configs --- flash_attn/flash_attn_triton_amd/bwd.py | 58 ++++++++++++++++++------- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 1f33b8483a1..4bf52a06ca5 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -47,6 +47,7 @@ def get_bwd_configs(autotune: bool): if arch == "gfx942": if get_cu_count() < 304: preprocess_autotune_configs = [ + triton.Config({"PRE_BLOCK": 64, "waves_per_eu": 1}, num_stages=1, num_warps=8), triton.Config({"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8), triton.Config({"PRE_BLOCK": 128, "waves_per_eu": 2}, num_stages=1, num_warps=4), ] @@ -56,6 +57,8 @@ def get_bwd_configs(autotune: bool): ] causal_autotune_configs = [ triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + triton.Config({"BLOCK_M1": 64, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), ] else: preprocess_autotune_configs = [ @@ -86,7 +89,7 @@ def get_bwd_configs(autotune: bool): return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) - # params + # param options PRE_BLOCK_OPTIONS = [64, 128] # og: 128 PRE_WAVES_PER_EU_OPTIONS=[1, 2] PRE_NUM_STAGES_OPTIONS=[1, 2] @@ -95,21 +98,28 @@ def get_bwd_configs(autotune: bool): NUM_WARPS_OPTIONS = [4, 8] # og: 4 WAVES_PER_EU_OPTIONS = [1, 2] # og: 1 MATRIX_INSTR_NONKDIM_OPTIONS = [16, 32] # og: 16 - BLOCK_M1_OPTIONS = [ # og: 32 + CAUSAL_BLOCK_M1_OPTIONS = [ # og: 32 + 32, 64, + ] + CAUSAL_BLOCK_N1_M2_OPTIONS = [ # og: 128 + 64, 128, 256 + ] + CAUSAL_BLOCK_N2_OPTIONS = [ # og: 32 32, 64 ] - BLOCK_N1_M2_OPTIONS = [ # og: 128 - 64, 128 + NON_CAUSAL_BLOCK_M1_OPTIONS = [ # og: 32 + 32, 64 ] - BLOCK_N2_OPTIONS = [ # og: 32 + NON_CAUSAL_BLOCK_N1_M2_OPTIONS = [ # og: 128 + 64, 128, 256 + ] + NON_CAUSAL_BLOCK_N2_OPTIONS = [ # og: 32 32, 64 ] BLK_SLICE_FACTOR_OPTIONS = [2] # og: 2 # ==================== sweep configs ================================ preprocess_autotune_configs = [] - causal_autotune_configs = [] - noncausal_autotune_configs = [] for pre_num_warps in PRE_NUM_WARPS_OPTIONS: for pre_num_stages in PRE_NUM_STAGES_OPTIONS: for pre_waves in PRE_WAVES_PER_EU_OPTIONS: @@ -121,15 +131,15 @@ def get_bwd_configs(autotune: bool): }, num_stages=pre_num_stages, num_warps=pre_num_warps) ) + causal_autotune_configs = [] for num_warps in NUM_WARPS_OPTIONS: for num_stages in NUM_STAGES_OPTIONS: for waves in WAVES_PER_EU_OPTIONS: for matrix_instr_nonkdim in MATRIX_INSTR_NONKDIM_OPTIONS: - # Causal and non-causal configs - for m1 in BLOCK_M1_OPTIONS: - for n1 in BLOCK_N1_M2_OPTIONS: + for m1 in CAUSAL_BLOCK_M1_OPTIONS: + for n1 in CAUSAL_BLOCK_N1_M2_OPTIONS: m2 = n1 - for n2 in BLOCK_N2_OPTIONS: + for n2 in CAUSAL_BLOCK_N2_OPTIONS: # Ensure constraint assert n1 == m2, f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" @@ -144,6 +154,18 @@ def get_bwd_configs(autotune: bool): }, num_stages=num_stages, num_warps=num_warps) ) + noncausal_autotune_configs = [] + for num_warps in NUM_WARPS_OPTIONS: + for num_stages in NUM_STAGES_OPTIONS: + for waves in WAVES_PER_EU_OPTIONS: + for matrix_instr_nonkdim in MATRIX_INSTR_NONKDIM_OPTIONS: + for m1 in NON_CAUSAL_BLOCK_M1_OPTIONS: + for n1 in NON_CAUSAL_BLOCK_N1_M2_OPTIONS: + m2 = n1 + for n2 in NON_CAUSAL_BLOCK_N2_OPTIONS: + # Ensure constraint + assert n1 == m2, f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" + for blk_slice in BLK_SLICE_FACTOR_OPTIONS: noncausal_autotune_configs.append( triton.Config({ "BLOCK_M1": m1, "BLOCK_N1": n1, @@ -153,13 +175,11 @@ def get_bwd_configs(autotune: bool): "matrix_instr_nonkdim": matrix_instr_nonkdim }, num_stages=num_stages, num_warps=num_warps) ) - return (preprocess_autotune_configs, preprocess_autotune_keys), \ (causal_autotune_configs, causal_autotune_keys), \ (noncausal_autotune_configs, noncausal_autotune_keys) - ( (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), @@ -358,7 +378,11 @@ def _bwd_fused_atomics_dq_inner( # dq # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. if IS_FP8: - dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) * descale_k + # Rewrite dq += ds @ kT.T as dq += (kT @ ds.T).T + # This puts FP8 tensor (kT) on LHS of dot product + # Cast the transposed ds to FP8 to match kT's dtype + ds_transposed = tl.trans(ds).to(kT.type.element_ty) + dq += tl.trans(tl.dot(kT, ds_transposed)) * descale_k else: dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) @@ -2630,7 +2654,11 @@ def _bwd_dq_inner( # Compute dQ. # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. if IS_FP8: - dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) * descale_k + # Rewrite dq += ds @ kT.T as dq += (kT @ ds.T).T + # This puts FP8 tensor (kT) on LHS of dot product + # Cast the transposed ds to FP8 to match kT's dtype + ds_transposed = tl.trans(ds).to(kT.type.element_ty) + dq += tl.trans(tl.dot(kT, ds_transposed)) * descale_k else: dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) # Increment pointers. From b7e3e48722f415e9a33d785cf4f1bb257d8a2c0a Mon Sep 17 00:00:00 2001 From: Michael Date: Thu, 9 Oct 2025 07:39:58 -0500 Subject: [PATCH 16/33] fa v3 tests passing --- flash_attn/flash_attn_triton_amd/bwd.py | 12 ++-- .../flash_attn_triton_amd/fwd_decode.py | 35 +++++++---- .../flash_attn_triton_amd/fwd_prefill.py | 36 ++++++++---- .../flash_attn_triton_amd/interface_v3.py | 58 ------------------- 4 files changed, 54 insertions(+), 87 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 4bf52a06ca5..e593afb1647 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -539,7 +539,7 @@ def _bwd_fused_atomics_dkdv_inner( dsT_transposed = tl.trans(dsT).to(qT.type.element_ty) dk += tl.trans(tl.dot(qT, dsT_transposed)) * descale_q else: - dk += tl.dot(dsT, tl.trans(qT)) + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) # increment pointers curr_m += step_m @@ -723,7 +723,7 @@ def _bwd_fused_atomics_dkdvdq_inner( dsT_transposed = tl.trans(dsT).to(qT.type.element_ty) dk += tl.trans(tl.dot(qT, dsT_transposed)) * descale_q else: - dk += tl.dot(dsT, tl.trans(qT)) + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) # We can compute the dq_partial here and do a atomic add to the correct memory location # NOTE: Possible problems with the atomic add: contention, is inside a loop which has achieved bad perf before @@ -2494,7 +2494,7 @@ def _bwd_dkdv_inner( dsT_transposed = tl.trans(dsT).to(qT.type.element_ty) dk += tl.trans(tl.dot(qT, dsT_transposed)) * descale_q else: - dk += tl.dot(dsT, tl.trans(qT)) + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) # Increment pointers. curr_m += step_m qT_ptrs += step_m * stride_qm @@ -3924,8 +3924,7 @@ def attention_backward_triton_split_fused_no_atomics_impl( print(f"FP8 path triggered in bwd.py") else: FP8_MAX = None - stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = ( - ) = None + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = None # alibi setup use_alibi, (stride_az, stride_ah) = ( @@ -4279,8 +4278,7 @@ def attention_backward_triton_fused_atomics_impl( print(f"FP8 path triggered in bwd.py (fused_atomics)") else: FP8_MAX = None - stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = ( - ) = None + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = None descale_strides = ( stride_descale_q_z, stride_descale_k_z, diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index 5186f52ecdf..93bfbdb2623 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -1,4 +1,5 @@ import os +import warnings import torch import triton import triton.language as tl @@ -1122,17 +1123,14 @@ def attention_forward_decode_triton_impl( # FP8 support IS_FP8 = is_fp8([q, k_cache, v_cache]) if IS_FP8: - CAST_TO_REC = str(os.getenv("CAST_TO_REC", "0")).lower() in ("1", "true", "yes", "on") - if CAST_TO_REC: - rec = get_recommended_fp8_dtype(q) - if q.dtype != rec: - raise TypeError( - f"FP8 dtype mismatch: received {q.dtype}, expected recommended {rec}. " - "Convert to the recommended FP8 dtype before calling (handled in interface)." - ) + rec = get_recommended_fp8_dtype(q) + if q.dtype != rec: + warnings.warn( + f"FP8 dtype mismatch: received {q.dtype}, expected recommended {rec}. " + "Convert to the recommended FP8 dtype before calling (handled in interface).", + UserWarning, + ) if (q_descale is None) or (k_descale is None) or (v_descale is None): - import warnings - warnings.warn( "FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning, @@ -1150,6 +1148,23 @@ def attention_forward_decode_triton_impl( v_descale = torch.ones( batch_size, nheads_vc, dtype=torch.float32, device=q.device ) + else: + # Enforce exact expected shapes; no reshaping or normalization. + assert ( + q_descale.dim() == 2 + and q_descale.shape[0] == batch_size + and q_descale.shape[1] == nheads_kc + ), f"q_descale expected shape ({batch_size}, {nheads_kc}) got {tuple(q_descale.shape)}" + assert ( + k_descale.dim() == 2 + and k_descale.shape[0] == batch_size + and k_descale.shape[1] == nheads_kc + ), f"k_descale expected shape ({batch_size}, {nheads_kc}) got {tuple(k_descale.shape)}" + assert ( + v_descale.dim() == 2 + and v_descale.shape[0] == batch_size + and v_descale.shape[1] == nheads_kc + ), f"v_descale expected shape ({batch_size}, {nheads_kc}) got {tuple(v_descale.shape)}" stride_q_descale_z, stride_q_descale_h = q_descale.stride() stride_k_descale_z, stride_k_descale_h = k_descale.stride() stride_v_descale_z, stride_v_descale_h = v_descale.stride() diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index c566e2c3c80..b635088e084 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -1,4 +1,5 @@ import os +import warnings import torch import triton import triton.language as tl @@ -1741,20 +1742,14 @@ def attention_forward_prefill_triton_impl( IS_FP8 = is_fp8([q, k, v]) if IS_FP8: FP8_MAX = torch.finfo(q.dtype).max + rec = get_recommended_fp8_dtype(q) + if q.dtype != rec: + warnings.warn( + f"FP8 dtype mismatch: received {q.dtype}, expected recommended {rec} for this architecture.", + UserWarning, + ) - CAST_TO_REC = str(os.getenv("CAST_TO_REC", "0")).lower() in ("1", "true", "yes", "on") - if CAST_TO_REC: - # check fp8 is the correct dtype for this architecture - rec = get_recommended_fp8_dtype(q) - if q.dtype != rec: - raise TypeError( - f"FP8 dtype mismatch: received {q.dtype}, expected recommended {rec} for this architecture. " - ) - - # Check and create default descale tensors if not provided if (q_descale is None) or (k_descale is None) or (v_descale is None): - import warnings - warnings.warn( "FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning, @@ -1772,6 +1767,23 @@ def attention_forward_prefill_triton_impl( v_descale = torch.ones( batch, nheads_k, dtype=torch.float32, device=q.device ) + else: + # Enforce exact expected shapes; no reshaping or normalization. + assert ( + q_descale.dim() == 2 + and q_descale.shape[0] == batch + and q_descale.shape[1] == nheads_k + ), f"q_descale expected shape ({batch}, {nheads_k}) got {tuple(q_descale.shape)}" + assert ( + k_descale.dim() == 2 + and k_descale.shape[0] == batch + and k_descale.shape[1] == nheads_k + ), f"k_descale expected shape ({batch}, {nheads_k}) got {tuple(k_descale.shape)}" + assert ( + v_descale.dim() == 2 + and v_descale.shape[0] == batch + and v_descale.shape[1] == nheads_k + ), f"v_descale expected shape ({batch}, {nheads_k}) got {tuple(v_descale.shape)}" # o should be fp32 or fp16/bf16 assert o.dtype in [ diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py index 7490c2f4aad..e7281373aac 100755 --- a/flash_attn/flash_attn_triton_amd/interface_v3.py +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -269,64 +269,6 @@ def fwd( else: out = out.zero_() - if is_fp8([q, k, v]): - CAST_TO_REC = str(os.getenv("CAST_TO_REC", "0")).lower() in ("1", "true", "yes", "on") - if CAST_TO_REC: - # check recommended dtype - rec = get_recommended_fp8_dtype(q) - if rec != q.dtype: - warnings.warn( - f"Casting q,k,v from {q.dtype} to recommended {rec} for this architecture.", - UserWarning, - ) - q = q.to(rec) - k = k.to(rec) - v = v.to(rec) - if k_new is not None and is_fp8(k_new): - rec_kn = get_recommended_fp8_dtype(k_new) - if rec_kn != k_new.dtype: - k_new = k_new.to(rec_kn) - if v_new is not None and is_fp8(v_new): - rec_vn = get_recommended_fp8_dtype(v_new) - if rec_vn != v_new.dtype: - v_new = v_new.to(rec_vn) - - - if (q_descale is None) or (k_descale is None) or (v_descale is None): - warnings.warn( - "FP8 tensors detected but descale factors not provided. Using default scale of 1.0", - UserWarning, - ) - else: - # Enforce exact expected shapes; no reshaping or normalization. - if layout == "bshd": - expected_batch = q.shape[0] - expected_q_heads = q.shape[2] - expected_kv_heads = k.shape[2] - else: # thd layout - expected_batch = ( - (len(cu_seqlens_q_local) - 1) - if cu_seqlens_q_local is not None - else 1 - ) - expected_q_heads = q.shape[1] - expected_kv_heads = k.shape[1] - - assert ( - q_descale.dim() == 2 - and q_descale.shape[0] == expected_batch - and q_descale.shape[1] == expected_kv_heads - ), f"q_descale expected shape ({expected_batch}, {expected_kv_heads}) got {tuple(q_descale.shape)}" - assert ( - k_descale.dim() == 2 - and k_descale.shape[0] == expected_batch - and k_descale.shape[1] == expected_kv_heads - ), f"k_descale expected shape ({expected_batch}, {expected_kv_heads}) got {tuple(k_descale.shape)}" - assert ( - v_descale.dim() == 2 - and v_descale.shape[0] == expected_batch - and v_descale.shape[1] == expected_kv_heads - ), f"v_descale expected shape ({expected_batch}, {expected_kv_heads}) got {tuple(v_descale.shape)}" # Handle causal mask causal_flag = bool(causal) From 1f9510da74c50e32d7e93b9455bc6b5980dc5962 Mon Sep 17 00:00:00 2001 From: Michael Date: Thu, 9 Oct 2025 07:46:50 -0500 Subject: [PATCH 17/33] better warning --- flash_attn/flash_attn_triton_amd/fwd_decode.py | 11 ++++++----- flash_attn/flash_attn_triton_amd/fwd_prefill.py | 9 +++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index 93bfbdb2623..657f0b3ddbb 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -7,6 +7,7 @@ from .utils import ( DEBUG, AUTOTUNE, + get_arch, get_padded_headsize, get_shape_and_strides_from_layout, apply_rotary, @@ -1123,12 +1124,12 @@ def attention_forward_decode_triton_impl( # FP8 support IS_FP8 = is_fp8([q, k_cache, v_cache]) if IS_FP8: - rec = get_recommended_fp8_dtype(q) - if q.dtype != rec: + rec_dtype = get_recommended_fp8_dtype(q) + if q.dtype != rec_dtype or k_cache.dtype != rec_dtype or v_cache.dtype != rec_dtype: + arch = get_arch() warnings.warn( - f"FP8 dtype mismatch: received {q.dtype}, expected recommended {rec}. " - "Convert to the recommended FP8 dtype before calling (handled in interface).", - UserWarning, + f"Use {rec_dtype} data type on {arch}. Got q: {q.dtype}, k: {k_cache.dtype}, v: {v_cache.dtype}", + UserWarning, ) if (q_descale is None) or (k_descale is None) or (v_descale is None): warnings.warn( diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index b635088e084..786fc7a2039 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -1742,11 +1742,12 @@ def attention_forward_prefill_triton_impl( IS_FP8 = is_fp8([q, k, v]) if IS_FP8: FP8_MAX = torch.finfo(q.dtype).max - rec = get_recommended_fp8_dtype(q) - if q.dtype != rec: + rec_dtype = get_recommended_fp8_dtype(q) + if q.dtype != rec_dtype or k.dtype != rec_dtype or v.dtype != rec_dtype: + arch = get_arch() warnings.warn( - f"FP8 dtype mismatch: received {q.dtype}, expected recommended {rec} for this architecture.", - UserWarning, + f"Use {rec_dtype} data type on {arch}. Got q: {q.dtype}, k: {k.dtype}, v: {v.dtype}", + UserWarning, ) if (q_descale is None) or (k_descale is None) or (v_descale is None): From d6dcef41050169f61ba3213e707a0a1af355aa64 Mon Sep 17 00:00:00 2001 From: Michael Date: Thu, 9 Oct 2025 13:14:10 -0500 Subject: [PATCH 18/33] clean up bwd launcher --- flash_attn/flash_attn_triton_amd/bwd.py | 1446 +++++++---------- .../flash_attn_triton_amd/fwd_prefill.py | 9 - flash_attn/flash_attn_triton_amd/utils.py | 2 +- 3 files changed, 609 insertions(+), 848 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index e593afb1647..8d2c1db5982 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -187,85 +187,8 @@ def get_bwd_configs(autotune: bool): ) = get_bwd_configs(AUTOTUNE) -# This function computes delta given output Out and gradient DO -# Here is the I/O shape: -# Out: (batch, nhead_q, max_seqlens_q, headDim) -# DO: (batch, nhead_q, max_seqlens_q, headDim) -# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at -@triton.jit -def _bwd_fused_atomics_preprocess( - o_ptr, - do_ptr, # noqa: E741 - delta_ptr, - stride_o_b, - stride_o_h, - stride_o_m, - stride_o_k, - stride_delta_b, - stride_delta_h, - stride_delta_m, - cu_seqlens_q, - max_seqlen_q, - BLOCK_M: tl.constexpr, - BLOCK_D_MODEL: tl.constexpr, - BLOCK_D_MODEL_POW2: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_FP8: tl.constexpr, -): - pid_m = tl.program_id(0) # seqlen - bid = tl.program_id(1) # batch - hid = tl.program_id(2) # head - - # Handle varlen - q_start = 0 - seqlen_q = max_seqlen_q - if IS_VARLEN: - q_start = tl.load(cu_seqlens_q + bid) - q_end = tl.load(cu_seqlens_q + bid + 1) - seqlen_q = q_end - q_start - else: - q_start = 0 - seqlen_q = max_seqlen_q - - # Compute offsets - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) - - # Offset O/DO by batch, head and q_start - offs = ( - bid * stride_o_b - + hid * stride_o_h - + q_start * stride_o_m - + offs_m[:, None] * stride_o_m - + offs_k[None, :] * stride_o_k - ) - - # create masks - mask_m = offs_m < seqlen_q - mask = mask_m[:, None] - PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 - if PADDED_HEAD: - mask &= offs_k[None, :] < BLOCK_D_MODEL - - # load [BLOCK_M, BLOCK_D_MODEL_POW2] - o = tl.load(o_ptr + offs, mask=mask, other=0.0) - do = tl.load(do_ptr + offs, mask=mask, other=0.0) - - # compute and write-back to delta - # NOTE: Both o and do are FP32 - delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) - - offs_delta = ( - bid * stride_delta_b - + hid * stride_delta_h - + q_start * stride_delta_m - + offs_m * stride_delta_m - ) - tl.store(delta_ptr + offs_delta, delta, mask=mask_m) - - @triton.jit -def _bwd_fused_atomics_dq_inner( +def _bwd_dq_inner_split( dq, q, K, @@ -393,7 +316,7 @@ def _bwd_fused_atomics_dq_inner( @triton.jit -def _bwd_fused_atomics_dkdv_inner( +def _bwd_dkdv_inner_split( dk, dv, Q, @@ -550,7 +473,7 @@ def _bwd_fused_atomics_dkdv_inner( @triton.jit -def _bwd_fused_atomics_dkdvdq_inner( +def _bwd_dkdvdq_inner_atomic( dk, dv, Q, @@ -745,7 +668,7 @@ def _bwd_fused_atomics_dkdvdq_inner( @triton.jit -def _bwd_kernel_fused_atomics_dkdvdq_causal( +def _bwd_kernel_fused_atomic_causal( q_ptr, k_ptr, v_ptr, @@ -968,7 +891,7 @@ def _bwd_kernel_fused_atomics_dkdvdq_causal( # if unaligned start_m is negative, the current N-tile has no block on the # diagonal of causal mask, so everything have no causal mask - dk, dv = _bwd_fused_atomics_dkdvdq_inner( + dk, dv = _bwd_dkdvdq_inner_atomic( dk, dv, # output tensors q_ptr_adj, @@ -1015,7 +938,7 @@ def _bwd_kernel_fused_atomics_dkdvdq_causal( num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) end_m = start_m + num_steps * BLOCK_M - dk, dv = _bwd_fused_atomics_dkdvdq_inner( + dk, dv = _bwd_dkdvdq_inner_atomic( dk, dv, # output tensors q_ptr_adj, @@ -1072,7 +995,7 @@ def _bwd_kernel_fused_atomics_dkdvdq_causal( @triton.jit -def _bwd_kernel_fused_atomics_dkdv_causal( +def _bwd_kernel_split_dkdv_causal( q_ptr, k_ptr, v_ptr, @@ -1279,7 +1202,7 @@ def _bwd_kernel_fused_atomics_dkdv_causal( # if start_m is negative, the current N-tile has no block on the # diagonal of causal mask, so everything have no causal mask - dk, dv = _bwd_fused_atomics_dkdv_inner( + dk, dv = _bwd_dkdv_inner_split( dk, dv, # output tensors q_ptr_adj, @@ -1321,7 +1244,7 @@ def _bwd_kernel_fused_atomics_dkdv_causal( num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) end_m = start_m + num_steps * BLOCK_M - dk, dv = _bwd_fused_atomics_dkdv_inner( + dk, dv = _bwd_dkdv_inner_split( dk, dv, # output tensors q_ptr_adj, @@ -1374,7 +1297,7 @@ def _bwd_kernel_fused_atomics_dkdv_causal( @triton.jit -def _bwd_kernel_fused_atomics_dq_causal( +def _bwd_kernel_split_dq_causal( q_ptr, k_ptr, v_ptr, @@ -1551,7 +1474,7 @@ def _bwd_kernel_fused_atomics_dq_causal( # but inside each call to _bwd_dq_inner, from left to right), but that's # not due to anything important. I just wanted to reuse the loop # structure for dK & dV above as much as possible. - dq = _bwd_fused_atomics_dq_inner( + dq = _bwd_dq_inner_split( dq, q, k_ptr_adj, @@ -1594,7 +1517,7 @@ def _bwd_kernel_fused_atomics_dq_causal( end_n -= num_steps * MASK_BLOCK_N num_steps = tl.cdiv(end_n, BLOCK_N) start_n = max(end_n - num_steps * BLOCK_N, 0) - dq = _bwd_fused_atomics_dq_inner( + dq = _bwd_dq_inner_split( dq, q, k_ptr_adj, @@ -1647,7 +1570,7 @@ def _bwd_kernel_fused_atomics_dq_causal( @triton.jit -def _bwd_kernel_fused_atomics_dkdvdq_noncausal( +def _bwd_kernel_fused_atomic_noncausal( Q, K, V, @@ -1806,7 +1729,7 @@ def _bwd_kernel_fused_atomics_dkdvdq_noncausal( start_m = 0 num_steps = tl.cdiv(seqlen_q, BLOCK_M) - dk, dv = _bwd_fused_atomics_dkdvdq_inner( + dk, dv = _bwd_dkdvdq_inner_atomic( dk, dv, Q_ptr, @@ -1862,7 +1785,7 @@ def _bwd_kernel_fused_atomics_dkdvdq_noncausal( @triton.jit -def _bwd_kernel_fused_atomics_dkdv_noncausal( +def _bwd_kernel_split_dkdv_noncausal( Q, K, V, @@ -2004,7 +1927,7 @@ def _bwd_kernel_fused_atomics_dkdv_noncausal( start_m = 0 num_steps = tl.cdiv(seqlen_q, BLOCK_M) - dk, dv = _bwd_fused_atomics_dkdv_inner( + dk, dv = _bwd_dkdv_inner_split( dk, dv, Q_ptr, @@ -2056,7 +1979,7 @@ def _bwd_kernel_fused_atomics_dkdv_noncausal( @triton.jit -def _bwd_kernel_fused_atomics_dq_noncausal( +def _bwd_kernel_split_dq_noncausal( Q, K, V, @@ -2190,7 +2113,7 @@ def _bwd_kernel_fused_atomics_dq_noncausal( end_n = seqlen_k num_steps = tl.cdiv(seqlen_k, BLOCK_N) dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32) - dq = _bwd_fused_atomics_dq_inner( + dq = _bwd_dq_inner_split( dq, q, K, @@ -2674,7 +2597,7 @@ def _bwd_dq_inner( use_cuda_graph=True, ) @triton.jit -def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), batch) +def bwd_kernel_fused_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), batch) Q, K, V, @@ -3253,7 +3176,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), b use_cuda_graph=True, ) @triton.jit -def bwd_kernel_noncausal( +def bwd_kernel_fused_noncausal( Q, K, V, @@ -3649,7 +3572,8 @@ def is_contiguous(x, name): DEBUG_TRITON_DETAIL: bool = False -def attention_backward_triton_split_fused_no_atomics_impl( +def attention_backward_triton_impl( + *, do: torch.Tensor, q: torch.Tensor, k: torch.Tensor, @@ -3667,22 +3591,14 @@ def attention_backward_triton_split_fused_no_atomics_impl( cu_seqlens_k: Optional[torch.Tensor], max_seqlen_q: Optional[int], max_seqlen_k: Optional[int], - dropout_p: float, - philox_seed: Optional[int], - philox_offset: Optional[int], - use_exp2: bool, - # fp8 - descale_q: Optional[torch.Tensor], - descale_k: Optional[torch.Tensor], - descale_v: Optional[torch.Tensor], - descale_o: Optional[torch.Tensor], - descale_dq: Optional[torch.Tensor], - descale_dk: Optional[torch.Tensor], - descale_dv: Optional[torch.Tensor], - # seqused for FA v3 seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, -): + dropout_p: float = 0.0, + philox_seed: Optional[int] = None, + philox_offset: Optional[int] = None, + use_exp2: bool = True, + mode: Literal["fused", "fused_atomic", "split"] = "fused", +) -> torch.Tensor: # get params, strides and shape IS_VARLEN = layout == "thd" use_dropout = dropout_p > 0.0 @@ -3894,27 +3810,25 @@ def attention_backward_triton_split_fused_no_atomics_impl( IS_FP8 = is_fp8([q, k, v]) if IS_FP8: FP8_MAX = torch.finfo(q.dtype).max - - # Check and create default descale tensors if not provided (for inputs) - if (descale_q is None) or (descale_k is None) or (descale_v is None): - warnings.warn( - "FP8 tensors detected but descale factors not provided. Using default scale of 1.0.", - UserWarning, - ) - # Create default descale tensors if not provided - # For GQA/MQA, q_descale should be shaped (batch, nheads_k) to match forward pass - if descale_q is None: - descale_q = torch.ones( - batch, nheads_k, dtype=torch.float32, device=q.device - ) - if descale_k is None: - descale_k = torch.ones( - batch, nheads_k, dtype=torch.float32, device=q.device - ) - if descale_v is None: - descale_v = torch.ones( - batch, nheads_k, dtype=torch.float32, device=q.device - ) + + warnings.warn( + "FP8 tensors detected in backward pass. Backward pass supports FP8 inputs but " + "descaling factors will default to 1.0.", + UserWarning, + ) + + # For GQA/MQA, q_descale should be shaped (batch, nheads_k) to match forward pass + descale_q = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) + + descale_k = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) + + descale_v = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None @@ -3924,6 +3838,7 @@ def attention_backward_triton_split_fused_no_atomics_impl( print(f"FP8 path triggered in bwd.py") else: FP8_MAX = None + descale_q = descale_k = descale_v = None stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = None # alibi setup @@ -4016,770 +3931,625 @@ def attention_backward_triton_split_fused_no_atomics_impl( dropout_mask.stride() ) - seqlen = max(max_seqlen_q, max_seqlen_k) - grid = lambda META: ( - nheads_k, - (seqlen + META["BLOCK_N1"] - 1) // META["BLOCK_N1"], - batch, - ) - if causal: - if DEBUG_TRITON: - print(f"bwd_kernel: grid = {grid}") # noqa: E701 - bwd_kernel_causal[grid]( - q, - k, - v, - sm_scale, - do, - dq, - dk, - dv, - softmax_lse, - delta, - stride_qb, - stride_qh, - stride_qm, - stride_qd, - stride_kb, - stride_kh, - stride_kn, - stride_kd, - stride_vb, - stride_vh, - stride_vn, - stride_vd, - stride_dqb, - stride_dqh, - stride_dqm, - stride_dqd, - stride_dkb, - stride_dkh, - stride_dkn, - stride_dkd, - stride_dvb, - stride_dvh, - stride_dvn, - stride_dvd, - stride_lse_b, - stride_lse_h, - stride_lse_m, - stride_delta_b, - stride_delta_h, - stride_delta_m, - stride_dob, - stride_doh, - stride_dom, - stride_dod, - stride_dropoutb, - stride_dropouth, - stride_dropoutm, - stride_dropoutn, - stride_descale_q_z, - stride_descale_k_z, - stride_descale_v_z, - stride_az, - stride_ah, - nheads_q, - nheads_k, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, # Pass seqused tensors - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - descale_q, - descale_k, - descale_v, - HEAD_DIM_QK=HEAD_DIM_QK, - HEAD_DIM_V=HEAD_DIM_V, - ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, - ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_ALIBI=use_alibi, - USE_EXP2=use_exp2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - USE_SEQUSED=( - seqused_q is not None or seqused_k is not None - ), # Add flag for seqused - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - else: - bwd_kernel_noncausal[grid]( - q, - k, - v, - sm_scale, - do, - dq, - dk, - dv, - softmax_lse, - delta, - stride_qb, - stride_qh, - stride_qm, - stride_qd, - stride_kb, - stride_kh, - stride_kn, - stride_kd, - stride_vb, - stride_vh, - stride_vn, - stride_vd, - stride_dqb, - stride_dqh, - stride_dqm, - stride_dqd, - stride_dkb, - stride_dkh, - stride_dkn, - stride_dkd, - stride_dvb, - stride_dvh, - stride_dvn, - stride_dvd, - stride_lse_b, - stride_lse_h, - stride_lse_m, - stride_delta_b, - stride_delta_h, - stride_delta_m, - stride_dob, - stride_doh, - stride_dom, - stride_dod, - stride_dropoutb, - stride_dropouth, - stride_dropoutm, - stride_dropoutn, - stride_descale_q_z, - stride_descale_k_z, - stride_descale_v_z, - stride_az, - stride_ah, - nheads_q, + # Choose which kernels to call based on mode + if mode == "fused": + seqlen = max(max_seqlen_q, max_seqlen_k) + grid = lambda META: ( nheads_k, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, # Pass seqused tensors - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset, - alibi_slopes, - descale_q, - descale_k, - descale_v, - HEAD_DIM_QK=HEAD_DIM_QK, - HEAD_DIM_V=HEAD_DIM_V, - ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, - ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - USE_ALIBI=use_alibi, - USE_EXP2=use_exp2, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - USE_SEQUSED=( - seqused_q is not None or seqused_k is not None - ), # Add flag for seqused - DEBUG_TRITON=DEBUG_TRITON, - DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, - ) - - return delta - - -def attention_backward_triton_fused_atomics_impl( - do: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - softmax_lse: torch.Tensor, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlen_q: int, - max_seqlen_k: int, - dropout_p: float, - philox_seed: Optional[int] = 0, - philox_offset: Optional[int] = 0, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - fused: bool = False, - # seqused for FA v3 (currently ignored in this implementation) - seqused_q: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, -): - IS_FP8 = is_fp8([q, k, v]) - if IS_FP8: - FP8_MAX = torch.finfo(q.dtype).max - - # Check and create default descale tensors if not provided - if (descale_q is None) or (descale_k is None) or (descale_v is None) or (descale_do is None): - warnings.warn( - "FP8 tensors detected but descale factors not provided. Using default scale of 1.0.", - UserWarning, - ) - # Determine batch size for creating default descale tensors - if cu_seqlens_q is not None: - batch = len(cu_seqlens_q) - 1 - else: - batch = q.shape[0] - - nheads_q = q.shape[1] if cu_seqlens_q is not None else q.shape[2] - nheads_k = k.shape[1] if cu_seqlens_q is not None else k.shape[2] - - # Create default descale tensors if not provided - if descale_q is None: - descale_q = torch.ones( - batch, nheads_q, dtype=torch.float32, device=q.device - ) - if descale_k is None: - descale_k = torch.ones( - batch, nheads_k, dtype=torch.float32, device=q.device - ) - if descale_v is None: - descale_v = torch.ones( - batch, nheads_k, dtype=torch.float32, device=q.device - ) - if descale_do is None: - descale_do = torch.ones( - batch, nheads_q, dtype=torch.float32, device=q.device - ) - - descale_strides = ( - descale_q.stride(0), - descale_k.stride(0), - descale_v.stride(0), - descale_do.stride(0), - ) - - if DEBUG: - print(f"FP8 path triggered in bwd.py (fused_atomics)") - else: - FP8_MAX = None - stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = None - descale_strides = ( - stride_descale_q_z, - stride_descale_k_z, - stride_descale_v_z, - ) - - IS_VARLEN = True if cu_seqlens_q is not None else False - - # get strides and shape - if IS_VARLEN: - # Layout for q,k,v is thd ie [total tokens, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = ( - len(cu_seqlens_q) - 1, - max_seqlen_q, - q.shape[1], - q.shape[2], - ) - seqlen_k, num_k_heads = max_seqlen_k, k.shape[1] - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - dq_strides = (0, dq.stride(1), dq.stride(0), dq.stride(2)) - dk_strides = (0, dk.stride(1), dk.stride(0), dk.stride(2)) - dv_strides = (0, dv.stride(1), dv.stride(0), dv.stride(2)) - do_strides = (0, do.stride(1), do.stride(0), do.stride(2)) - else: - # Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = q.shape - seqlen_k, num_k_heads = k.shape[1], k.shape[2] - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) - dq_strides = (dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3)) - dk_strides = (dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3)) - dv_strides = (dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3)) - do_strides = (do.stride(0), do.stride(2), do.stride(1), do.stride(3)) - - # BLOCK_D_MODEL, BLOCK_D_MODEL_POW2 - # padding for head_dim. Power of 2 or 16 - BLOCK_D_MODEL_POW2 = triton.next_power_of_2(head_sz) - BLOCK_D_MODEL_POW2 = max(BLOCK_D_MODEL_POW2, 16) - - # Configs - # PRE_BLOCK, BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 - # BLK_SLICE_FACTOR - NUM_WARPS, NUM_STAGES = 4, 1 - WAVES_PER_EU = 1 - PRE_BLOCK = 128 - # BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 64, 64, 16 - BLK_SLICE_FACTOR = 2 - - # init delta - delta = torch.zeros_like(softmax_lse) - if IS_VARLEN: - # [total_tokens, num_q_heads, seqlen_q] - delta_strides = (0, delta.stride(1), delta.stride(0)) - else: - # [batch, num_q_heads, seqlen_q] - delta_strides = delta.stride() - - # preprocess - # compute D(delta) = rowsum(dO*O). Note, multiplication is element-wise. - pre_grid = (triton.cdiv(max_seqlen_q, PRE_BLOCK), batch, num_q_heads) - _bwd_fused_atomics_preprocess[pre_grid]( - o, - do, - delta, - *o_strides, - *delta_strides, - descale_strides[3], - cu_seqlens_q, - max_seqlen_q, - BLOCK_M=PRE_BLOCK, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - ) - - # dropout_mask - use_dropout = dropout_p > 0.0 - if use_dropout: - dropout_mask = torch.zeros( - (batch, num_q_heads, max_seqlen_q, max_seqlen_k), - device=q.device, - dtype=torch.float32, + (seqlen + META["BLOCK_N1"] - 1) // META["BLOCK_N1"], + batch, ) - dropout_strides = dropout_mask.stride() - else: - dropout_mask = None - dropout_strides = (0, 0, 0, 0) - - grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, num_k_heads) - grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, num_k_heads) - - if ( - fused - ): # fuses dk, dv, dq computations into one kernel by computing the dq using atomic adds between workgroups - - BLOCK_N = ( - 128 if BLOCK_D_MODEL_POW2 < 160 else 64 - ) # larger head sizes lead to oom - config = { - "BLOCK_M": 32, - "BLOCK_N": BLOCK_N, - "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 1, - "BLK_SLICE_FACTOR": 2, - } - - num_k_pids = (max_seqlen_k + BLOCK_N - 1) // BLOCK_N - grid_dkdvdq = (batch * num_k_heads * num_k_pids,) - if causal: - _bwd_kernel_fused_atomics_dkdvdq_causal[grid_dkdvdq]( + if DEBUG_TRITON: + print(f"bwd_kernel: grid = {grid}") # noqa: E701 + bwd_kernel_fused_causal[grid]( q, k, v, sm_scale, do, + dq, dk, dv, - dq, softmax_lse, delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_az, + stride_ah, + nheads_q, + nheads_k, cu_seqlens_q, cu_seqlens_k, + seqused_q, + seqused_k, # Pass seqused tensors max_seqlen_q, max_seqlen_k, dropout_mask, dropout_p, philox_seed, philox_offset, + alibi_slopes, descale_q, descale_k, descale_v, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BATCH=batch, - NUM_K_PIDS=num_k_pids, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, - **config, + USE_SEQUSED=( + seqused_q is not None or seqused_k is not None + ), # Add flag for seqused + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) else: - _bwd_kernel_fused_atomics_dkdvdq_noncausal[grid_dkdvdq]( + bwd_kernel_fused_noncausal[grid]( q, k, v, sm_scale, do, + dq, dk, dv, - dq, softmax_lse, delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_az, + stride_ah, + nheads_q, + nheads_k, cu_seqlens_q, cu_seqlens_k, + seqused_q, + seqused_k, # Pass seqused tensors max_seqlen_q, max_seqlen_k, dropout_mask, dropout_p, philox_seed, philox_offset, + alibi_slopes, descale_q, descale_k, descale_v, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BATCH=batch, - NUM_K_PIDS=num_k_pids, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, - **config, + USE_SEQUSED=( + seqused_q is not None or seqused_k is not None + ), # Add flag for seqused + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) + elif mode == "fused_atomic": + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 64, 64, 16 + BLK_SLICE_FACTOR = 2 + BLOCK_D_MODEL_POW2 = max(triton.next_power_of_2(HEAD_DIM_QK), 16) + + grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) + grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) + + # fuses dk, dv, dq computations into one kernel by computing the dq using atomic adds between workgroups + BLOCK_N = ( + 128 if BLOCK_D_MODEL_POW2 < 160 else 64 + ) # larger head sizes lead to oom + config = { + "BLOCK_M": 32, + "BLOCK_N": BLOCK_N, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 1, + "BLK_SLICE_FACTOR": 2, + } - return delta - - # split kernels solution: one kernel computes dk, dv and the other computes dq - - if causal: - _bwd_kernel_fused_atomics_dkdv_causal[grid_dkdv]( - q, - k, - v, - sm_scale, - do, - dk, - dv, - softmax_lse, - delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset, - descale_q, - descale_k, - descale_v, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M1, - BLOCK_N=BLOCK_N1, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - _bwd_kernel_fused_atomics_dq_causal[grid_dq]( - q, - k, - v, - sm_scale, - do, - dq, - softmax_lse, - delta, - *q_strides, - *k_strides, - *v_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset, - descale_q, - descale_k, - descale_v, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M2, - BLOCK_N=BLOCK_N2, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - else: - _bwd_kernel_fused_atomics_dkdv_noncausal[grid_dkdv]( - q, - k, - v, - sm_scale, - do, - dk, - dv, - softmax_lse, - delta, - *q_strides, - *k_strides, - *v_strides, - *dk_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset, - descale_q, - descale_k, - descale_v, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M1, - BLOCK_N=BLOCK_N1, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - - _bwd_kernel_fused_atomics_dq_noncausal[grid_dq]( - q, - k, - v, - sm_scale, - do, - dq, - softmax_lse, - delta, - *q_strides, - *k_strides, - *v_strides, - *dq_strides, - *delta_strides, - *do_strides, - *dropout_strides, - *descale_strides, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_mask, - dropout_p, - philox_seed, - philox_offset, - descale_q, - descale_k, - descale_v, - NUM_Q_HEADS=num_q_heads, - NUM_K_HEADS=num_k_heads, - BLOCK_M=BLOCK_M2, - BLOCK_N=BLOCK_N2, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_D_MODEL=head_sz, - BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, - ENABLE_DROPOUT=use_dropout, - IS_VARLEN=IS_VARLEN, - IS_FP8=IS_FP8, - FP8_MAX=FP8_MAX, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, - waves_per_eu=WAVES_PER_EU, - ) - - return delta - + num_k_pids = (max_seqlen_k + BLOCK_N - 1) // BLOCK_N + grid_dkdvdq = (batch * nheads_k * num_k_pids,) -def attention_backward_triton_impl( - *, - do: torch.Tensor, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - softmax_lse: torch.Tensor, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor, - sm_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - layout: str, - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], - max_seqlen_q: Optional[int], - max_seqlen_k: Optional[int], - seqused_q: Optional[torch.Tensor] = None, - seqused_k: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - philox_seed: Optional[int] = None, - philox_offset: Optional[int] = None, - use_exp2: bool = True, - mode: str = "fused_no_atomics", -) -> torch.Tensor: - """Unified backward interface dispatching to atomics or no-atomics implementation. - - Parameters mirror the superset of the two legacy interfaces. The public API should - call ONLY this function going forward. - mode: 'fused_atomics' or 'fused_no_atomics'; layout: 'bshd' or 'thd'; use_exp2 retained for parity. - """ - # Allow FP8 dtypes and handle gradient tensor dtype casting - dq_original, dk_original, dv_original = None, None, None - do_original = None - - if is_fp8([q, k, v]): - warnings.warn( - "FP8 tensors detected in backward pass. Backward pass supports FP8 inputs but " - "descaling factors will default to 1.0.", - UserWarning, - ) + if causal: + _bwd_kernel_fused_atomic_causal[grid_dkdvdq]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + dq, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BATCH=batch, + NUM_K_PIDS=num_k_pids, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + **config, + ) + else: + _bwd_kernel_fused_atomic_noncausal[grid_dkdvdq]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + dq, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BATCH=batch, + NUM_K_PIDS=num_k_pids, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + **config, + ) + elif mode == "split": + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 64, 64, 16 + BLK_SLICE_FACTOR = 2 + BLOCK_D_MODEL_POW2 = max(triton.next_power_of_2(HEAD_DIM_QK), 16) + + grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) + grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) + + if causal: + _bwd_kernel_split_dkdv_causal[grid_dkdv]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M1, + BLOCK_N=BLOCK_N1, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + _bwd_kernel_split_dq_causal[grid_dq]( + q, + k, + v, + sm_scale, + do, + dq, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M2, + BLOCK_N=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + else: + _bwd_kernel_split_dkdv_noncausal[grid_dkdv]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M1, + BLOCK_N=BLOCK_N1, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) - if mode == "fused_atomics": - delta = attention_backward_triton_fused_atomics_impl( - do, - q, - k, - v, - o, - softmax_lse, - dq, - dk, - dv, - sm_scale, - alibi_slopes, - causal, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q if max_seqlen_q is not None else q.shape[1], - max_seqlen_k if max_seqlen_k is not None else k.shape[1], - dropout_p, - philox_seed or 0, - philox_offset or 0, - None, - None, - None, - None, - True, # fused flag - None, - None, - ) - elif mode == "fused_no_atomics": - delta = attention_backward_triton_split_fused_no_atomics_impl( - do, - q, - k, - v, - o, - softmax_lse, - dq, - dk, - dv, - sm_scale, - alibi_slopes, - causal, - layout, # layout required here - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - philox_seed, - philox_offset, - use_exp2, - None, - None, - None, - None, - None, - None, - None, - seqused_q, - seqused_k, - ) + _bwd_kernel_split_dq_noncausal[grid_dq]( + q, + k, + v, + sm_scale, + do, + dq, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M2, + BLOCK_N=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) else: raise ValueError( - f"Unknown backward mode '{mode}'. Expected 'fused_atomics' or 'fused_no_atomics'." + f"Unknown backward mode '{mode}'. Expected 'split', 'fused_atomic' or 'fused'." ) + return delta diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 786fc7a2039..74e9729d465 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -40,15 +40,6 @@ def get_fwd_configs(autotune: bool): "HK", ] - # fallback config - if False: - configs.append(triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - )) - return configs, keys - # get best config for the architecture if not autotune: arch = get_arch() diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 03e27e33660..5a42c89684a 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -50,7 +50,7 @@ ) if USE_TRITON_ROCM: # TODO remove this random.seed(42) -BWD_MODE = os.environ.get("BWD_MODE", "fused_no_atomics").lower() +BWD_MODE = os.environ.get("BWD_MODE", "fused").lower() DROPOUT_USE_PYTORCH = False DROPOUT_DUMP = False USE_EXP2 = True From 948df0146aeba1e3ebcdea5f90c1fe1e46af5088 Mon Sep 17 00:00:00 2001 From: Michael Date: Thu, 9 Oct 2025 19:19:31 -0500 Subject: [PATCH 19/33] v3 passing --- flash_attn/flash_attn_triton_amd/bwd.py | 21 +- .../flash_attn_triton_amd/fwd_decode.py | 21 +- .../flash_attn_triton_amd/fwd_prefill.py | 37 ++-- .../flash_attn_triton_amd/interface_v2.py | 179 +++++++++++++++--- .../flash_attn_triton_amd/interface_v3.py | 76 +++++++- flash_attn/flash_attn_triton_amd/utils.py | 1 + 6 files changed, 270 insertions(+), 65 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 8d2c1db5982..a6f96e4011b 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -3583,6 +3583,7 @@ def attention_backward_triton_impl( dq: torch.Tensor, dk: torch.Tensor, dv: torch.Tensor, + delta: torch.Tensor, sm_scale: float, alibi_slopes: Optional[torch.Tensor], causal: bool, @@ -3598,7 +3599,7 @@ def attention_backward_triton_impl( philox_offset: Optional[int] = None, use_exp2: bool = True, mode: Literal["fused", "fused_atomic", "split"] = "fused", -) -> torch.Tensor: +): # get params, strides and shape IS_VARLEN = layout == "thd" use_dropout = dropout_p > 0.0 @@ -3856,11 +3857,14 @@ def attention_backward_triton_impl( ACTUAL_HEAD_DIM_QK = head_size_qk ACTUAL_HEAD_DIM_V = head_size_v - # init delta + # Validate pre-allocated delta tensor if IS_VARLEN: # Shape expected by interface varlen backward: (Hq, Total_Q) total_q, _, _ = q.shape - delta = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) + assert delta.shape[0] == nheads_q, f"delta.shape[0] ({delta.shape[0]}) must equal nheads_q ({nheads_q})" + assert delta.shape[1] >= total_q, f"delta.shape[1] ({delta.shape[1]}) must be >= total_q ({total_q})" + assert delta.dtype == torch.float32, f"delta must be float32, got {delta.dtype}" + assert delta.device == q.device, f"delta must be on same device as q" stride_delta_b, stride_delta_h, stride_delta_m = ( 0, delta.stride(0), @@ -3869,9 +3873,11 @@ def attention_backward_triton_impl( else: # Shape expected by dense backward: (B, Hq, Sq) seqlen_q = q.shape[1] - delta = torch.zeros( - (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 - ) + assert delta.shape[0] == batch, f"delta.shape[0] ({delta.shape[0]}) must equal batch ({batch})" + assert delta.shape[1] == nheads_q, f"delta.shape[1] ({delta.shape[1]}) must equal nheads_q ({nheads_q})" + assert delta.shape[2] >= seqlen_q, f"delta.shape[2] ({delta.shape[2]}) must be >= seqlen_q ({seqlen_q})" + assert delta.dtype == torch.float32, f"delta must be float32, got {delta.dtype}" + assert delta.device == q.device, f"delta must be on same device as q" stride_delta_b, stride_delta_h, stride_delta_m = delta.stride() pre_grid = lambda META: ( @@ -4550,6 +4556,3 @@ def attention_backward_triton_impl( raise ValueError( f"Unknown backward mode '{mode}'. Expected 'split', 'fused_atomic' or 'fused'." ) - - - return delta diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index 657f0b3ddbb..fb096de96db 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -846,6 +846,7 @@ def attention_forward_decode_triton_impl( k_new: Optional[torch.Tensor], v_new: Optional[torch.Tensor], out: torch.Tensor, + softmax_lse: torch.Tensor, sm_scale: float, causal: bool, window_size_left: int, @@ -1104,11 +1105,19 @@ def attention_forward_decode_triton_impl( dtype=torch.float32, device=q.device, ) - lse = torch.empty( - (batch_size * n_group_q * heads_per_group_q, seqlen_q), - dtype=torch.float32, - device=q.device, - ) + + # Validate pre-allocated softmax_lse tensor + # Expected shape after view: (batch_size, n_group_q * heads_per_group_q, seqlen_q) + # Internal shape: (batch_size * n_group_q * heads_per_group_q, seqlen_q) + expected_h_total = batch_size * n_group_q * heads_per_group_q + assert softmax_lse.shape[0] == batch_size, f"softmax_lse.shape[0] ({softmax_lse.shape[0]}) must equal batch_size ({batch_size})" + assert softmax_lse.shape[1] == n_group_q * heads_per_group_q, f"softmax_lse.shape[1] ({softmax_lse.shape[1]}) must equal n_group_q * heads_per_group_q ({n_group_q * heads_per_group_q})" + assert softmax_lse.shape[2] >= seqlen_q, f"softmax_lse.shape[2] ({softmax_lse.shape[2]}) must be >= seqlen_q ({seqlen_q})" + assert softmax_lse.dtype == torch.float32, f"softmax_lse must be float32, got {softmax_lse.dtype}" + assert softmax_lse.device == q.device, f"softmax_lse must be on same device as q" + + # Create internal lse view for kernel use + lse = softmax_lse.view(expected_h_total, -1)[:, :seqlen_q].contiguous() # get intermediate tensor strides stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d = out_splitk.stride() @@ -1381,5 +1390,3 @@ def attention_forward_decode_triton_impl( PADDED_HEAD=is_padded_head, num_warps=num_warps_reduce, ) - - return lse.view(batch_size, n_group_q * heads_per_group_q, seqlen_q) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 74e9729d465..dc259e99675 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -1495,6 +1495,8 @@ def attention_forward_prefill_triton_impl( k: torch.Tensor, v: torch.Tensor, o: torch.Tensor, + softmax_lse: torch.Tensor, + sd_mask: Optional[torch.Tensor], sm_scale: float, alibi_slopes: Optional[torch.Tensor], causal: bool, @@ -1602,10 +1604,11 @@ def attention_forward_prefill_triton_impl( batch = len(cu_seqlens_q) - 1 head_size_qk = head_size_q - # softmax_lse shape - softmax_lse = torch.zeros( - (nheads_q, total_seqlen_q), device=q.device, dtype=torch.float32 - ) + # Assert softmax_lse tensor is large enough + assert softmax_lse.shape[0] >= nheads_q, f"softmax_lse.shape[0]={softmax_lse.shape[0]} must be >= nheads_q={nheads_q}" + assert softmax_lse.shape[1] >= total_seqlen_q, f"softmax_lse.shape[1]={softmax_lse.shape[1]} must be >= total_seqlen_q={total_seqlen_q}" + assert softmax_lse.dtype == torch.float32, f"softmax_lse must be float32, got {softmax_lse.dtype}" + assert softmax_lse.device == q.device, f"softmax_lse must be on same device as q" # strides stride_qb, stride_qh, stride_qm, stride_qd = ( @@ -1678,10 +1681,12 @@ def attention_forward_prefill_triton_impl( max_seqlens_q = seqlen_q max_seqlens_k = seqlen_k - # softmax_lse shape - softmax_lse = torch.zeros( - (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 - ) + # Assert softmax_lse tensor is large enough + assert softmax_lse.shape[0] >= batch, f"softmax_lse.shape[0]={softmax_lse.shape[0]} must be >= batch={batch}" + assert softmax_lse.shape[1] >= nheads_q, f"softmax_lse.shape[1]={softmax_lse.shape[1]} must be >= nheads_q={nheads_q}" + assert softmax_lse.shape[2] >= seqlen_q, f"softmax_lse.shape[2]={softmax_lse.shape[2]} must be >= seqlen_q={seqlen_q}" + assert softmax_lse.dtype == torch.float32, f"softmax_lse must be float32, got {softmax_lse.dtype}" + assert softmax_lse.device == q.device, f"softmax_lse must be on same device as q" # strides stride_qb, stride_qh, stride_qm, stride_qd = ( @@ -1823,11 +1828,15 @@ def attention_forward_prefill_triton_impl( # only. This return holds no useful output aside from debugging. NEEDS_SDMASK = (dropout_p > 0.0) or return_softmax if NEEDS_SDMASK: - sd_mask = torch.zeros( - (batch, nheads_q, max_seqlens_q, max_seqlens_k), - device=q.device, - dtype=torch.float32, - ) + assert sd_mask is not None, "sd_mask must be provided when return_softmax=True or dropout_p > 0" + # Assert sd_mask tensor is large enough + assert sd_mask.shape[0] >= batch, f"sd_mask.shape[0]={sd_mask.shape[0]} must be >= batch={batch}" + assert sd_mask.shape[1] >= nheads_q, f"sd_mask.shape[1]={sd_mask.shape[1]} must be >= nheads_q={nheads_q}" + assert sd_mask.shape[2] >= max_seqlens_q, f"sd_mask.shape[2]={sd_mask.shape[2]} must be >= max_seqlens_q={max_seqlens_q}" + assert sd_mask.shape[3] >= max_seqlens_k, f"sd_mask.shape[3]={sd_mask.shape[3]} must be >= max_seqlens_k={max_seqlens_k}" + assert sd_mask.dtype == torch.float32, f"sd_mask must be float32, got {sd_mask.dtype}" + assert sd_mask.device == q.device, f"sd_mask must be on same device as q" + if DROPOUT_USE_PYTORCH: dropout_mask = create_dropout_mask( dropout_p, @@ -1940,5 +1949,3 @@ def attention_forward_prefill_triton_impl( FP8_P_DESCALE=False, USE_SEQUSED=(seqused_q is not None or seqused_k is not None), ) # Add flag for seqused - - return softmax_lse, sd_mask if return_softmax else None diff --git a/flash_attn/flash_attn_triton_amd/interface_v2.py b/flash_attn/flash_attn_triton_amd/interface_v2.py index 9df299666eb..71e1630fc2f 100644 --- a/flash_attn/flash_attn_triton_amd/interface_v2.py +++ b/flash_attn/flash_attn_triton_amd/interface_v2.py @@ -4,7 +4,7 @@ from .fwd_prefill import attention_forward_prefill_triton_impl from .fwd_decode import attention_forward_decode_triton_impl from .bwd import attention_backward_triton_impl -from .utils import DEBUG, USE_EXP2, BWD_MODE, PHILOX_SEED, PHILOX_OFFSET +from .utils import DEBUG, USE_EXP2, BWD_MODE, PHILOX_SEED, PHILOX_OFFSET, SHAPE_EXPECTATIONS, round_multiple def fwd( @@ -77,14 +77,48 @@ def fwd( nheads_k = k.shape[2] assert (nheads_q % nheads_k) == 0 + # Create output tensors based on shape expectations + if SHAPE_EXPECTATIONS == "rounded": + # Rounded shapes for NVIDIA compatibility + softmax_lse = torch.zeros( + (batch, nheads_q, round_multiple(max_seqlen_q, 128)), + device=q.device, + dtype=torch.float32, + ) + if return_softmax: + sd_mask = torch.zeros( + (batch, nheads_q, round_multiple(max_seqlen_q, 128), round_multiple(max_seqlen_k, 128)), + device=q.device, + dtype=torch.float32, + ) + else: + sd_mask = None + else: + # Exact shapes for AMD + softmax_lse = torch.zeros( + (batch, nheads_q, max_seqlen_q), + device=q.device, + dtype=torch.float32, + ) + if return_softmax: + sd_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=torch.float32, + ) + else: + sd_mask = None + # call implementation if DEBUG: print("Using Triton implementation") - softmax_lse, sd_mask = attention_forward_prefill_triton_impl( + attention_forward_prefill_triton_impl( q, k, v, out, + softmax_lse, + sd_mask, softmax_scale, alibi_slopes, causal, @@ -104,6 +138,11 @@ def fwd( None, None, None, + None, + None, + None, + None, + None, ) if DEBUG: @@ -116,23 +155,37 @@ def fwd( # --- Assertions (shape + dtype contracts) --- # out: (B, Sq, Hq, D) assert out.shape == q.shape, f"[fwd] out shape {out.shape} != q shape {q.shape}" - # softmax_lse: (B, Hq, Sq) - expected_lse_shape = (q.shape[0], q.shape[2], q.shape[1]) - assert ( - softmax_lse.shape == expected_lse_shape - ), f"[fwd] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" + # softmax_lse dtype assert ( softmax_lse.dtype == torch.float32 ), f"[fwd] softmax_lse dtype {softmax_lse.dtype} != torch.float32" + # softmax_lse shape depends on SHAPE_EXPECTATIONS + if SHAPE_EXPECTATIONS == "rounded": + expected_lse_shape = (q.shape[0], q.shape[2], round_multiple(q.shape[1], 128)) + else: + expected_lse_shape = (q.shape[0], q.shape[2], q.shape[1]) + assert ( + softmax_lse.shape == expected_lse_shape + ), f"[fwd] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" if return_softmax: # sd_mask: (B, Hq, Sq, Sk) assert sd_mask is not None, "[fwd] return_softmax=True but sd_mask is None" assert sd_mask.dim() == 4, f"[fwd] sd_mask dim {sd_mask.dim()} != 4" - assert ( - sd_mask.shape[0] == q.shape[0] - and sd_mask.shape[1] == q.shape[2] - and sd_mask.shape[2] == q.shape[1] - ), f"[fwd] sd_mask leading dims {sd_mask.shape[:3]} mismatch (B,Hq,Sq) {(q.shape[0], q.shape[2], q.shape[1])}" + if SHAPE_EXPECTATIONS == "rounded": + expected_sq = round_multiple(q.shape[1], 128) + expected_sk = round_multiple(k.shape[1], 128) + assert ( + sd_mask.shape[0] == q.shape[0] + and sd_mask.shape[1] == q.shape[2] + and sd_mask.shape[2] == expected_sq + and sd_mask.shape[3] == expected_sk + ), f"[fwd] sd_mask shape {sd_mask.shape} != (B={q.shape[0]}, Hq={q.shape[2]}, Sq={expected_sq}, Sk={expected_sk})" + else: + assert ( + sd_mask.shape[0] == q.shape[0] + and sd_mask.shape[1] == q.shape[2] + and sd_mask.shape[2] == q.shape[1] + ), f"[fwd] sd_mask leading dims {sd_mask.shape[:3]} mismatch (B,Hq,Sq) {(q.shape[0], q.shape[2], q.shape[1])}" else: assert sd_mask is None, "[fwd] return_softmax=False but sd_mask is not None" @@ -193,7 +246,18 @@ def bwd( dv = torch.zeros_like(v) if dv is None else dv.zero_() # get shape - batch, _, nheads_q, _ = q.shape + batch, seqlen_q, nheads_q, _ = q.shape + + # Create delta tensor with shape based on expectations + # delta (softmax_d) : (B, Hq, Sq) or (B, Hq, round_multiple(Sq, 128)) + if SHAPE_EXPECTATIONS == "rounded": + delta = torch.zeros( + (batch, nheads_q, round_multiple(seqlen_q, 128)), + device=q.device, + dtype=torch.float32, + ) + else: + delta = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32) # Upstream change: base seeding logic on provided rng_state instead of dropout probability. if rng_state is not None: @@ -212,7 +276,7 @@ def bwd( # call implementation if DEBUG: print(f"Using Triton implementation in {BWD_MODE} mode") - delta = attention_backward_triton_impl( + attention_backward_triton_impl( do=dout, q=q, k=k, @@ -222,13 +286,14 @@ def bwd( dq=dq, dk=dk, dv=dv, + delta=delta, sm_scale=softmax_scale, alibi_slopes=alibi_slopes, causal=causal, layout="bshd", cu_seqlens_q=None, cu_seqlens_k=None, - max_seqlen_q=q.shape[1], + max_seqlen_q=seqlen_q, max_seqlen_k=k.shape[1], seqused_q=None, seqused_k=None, @@ -249,7 +314,10 @@ def bwd( assert dk.shape == k.shape, f"[bwd] dk shape {dk.shape} != k shape {k.shape}" assert dv.shape == v.shape, f"[bwd] dv shape {dv.shape} != v shape {v.shape}" # delta (softmax_d) : (B, Hq, Sq) - expected_delta_shape = (q.shape[0], q.shape[2], q.shape[1]) + if SHAPE_EXPECTATIONS == "rounded": + expected_delta_shape = (q.shape[0], q.shape[2], round_multiple(q.shape[1], 128)) + else: + expected_delta_shape = (q.shape[0], q.shape[2], q.shape[1]) assert ( delta.shape == expected_delta_shape ), f"[bwd] delta shape {delta.shape} != {expected_delta_shape}" @@ -324,7 +392,33 @@ def varlen_fwd( # Layout and basic info for varlen layout = "thd" batch = len(cu_seqlens_q) - 1 - _, nheads_q, _ = q.shape + total_q, nheads_q, _ = q.shape + + # Create softmax_lse tensor - varlen always uses exact shape (Hq, Total_Q) + softmax_lse = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) + + # Create sd_mask tensor if needed + if return_softmax: + # sd_mask: (B, Hq, Sq, Sk) - shape based on expectations + if SHAPE_EXPECTATIONS == "rounded": + sd_mask = torch.zeros( + ( + batch, + nheads_q, + round_multiple(max_seqlen_q, 128), + round_multiple(max_seqlen_k, 128), + ), + device=q.device, + dtype=q.dtype, + ) + else: + sd_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=q.dtype, + ) + else: + sd_mask = None if alibi_slopes is not None: if alibi_slopes.dim() == 1: @@ -346,11 +440,13 @@ def varlen_fwd( # call implementation if DEBUG: print("Using Triton implementation") - softmax_lse, sd_mask = attention_forward_prefill_triton_impl( + attention_forward_prefill_triton_impl( q, k, v, out, + softmax_lse, + sd_mask, softmax_scale, alibi_slopes, causal, @@ -396,12 +492,19 @@ def varlen_fwd( sd_mask is not None ), "[varlen_fwd] return_softmax=True but sd_mask is None" assert sd_mask.dim() == 4, f"[varlen_fwd] sd_mask dim {sd_mask.dim()} != 4" - assert sd_mask.shape[0] == ( - len(cu_seqlens_q) - 1 - ), f"[varlen_fwd] sd_mask batch {sd_mask.shape[0]} != {len(cu_seqlens_q)-1}" - assert ( - sd_mask.shape[1] == q.shape[1] - ), f"[varlen_fwd] sd_mask nheads {sd_mask.shape[1]} != {q.shape[1]}" + batch = len(cu_seqlens_q) - 1 + assert sd_mask.shape[0] == batch, f"[varlen_fwd] sd_mask batch {sd_mask.shape[0]} != {batch}" + assert sd_mask.shape[1] == q.shape[1], f"[varlen_fwd] sd_mask nheads {sd_mask.shape[1]} != {q.shape[1]}" + if SHAPE_EXPECTATIONS == "rounded": + expected_sq = round_multiple(max_seqlen_q, 128) + expected_sk = round_multiple(max_seqlen_k, 128) + assert ( + sd_mask.shape[2] == expected_sq and sd_mask.shape[3] == expected_sk + ), f"[varlen_fwd] sd_mask shape {sd_mask.shape} != (B={batch}, Hq={q.shape[1]}, Sq={expected_sq}, Sk={expected_sk})" + else: + assert ( + sd_mask.shape[2] == max_seqlen_q and sd_mask.shape[3] == max_seqlen_k + ), f"[varlen_fwd] sd_mask shape {sd_mask.shape} != (B={batch}, Hq={q.shape[1]}, Sq={max_seqlen_q}, Sk={max_seqlen_k})" else: assert ( sd_mask is None @@ -476,7 +579,16 @@ def varlen_bwd( # get shape batch = len(cu_seqlens_q) - 1 - _, nheads_q, _ = q.shape + total_q, nheads_q, _ = q.shape + + # Create delta tensor with shape based on expectations + # delta (softmax_d) : (Hq, Total_Q) or (Hq, Total_Q + 128*batch) + if SHAPE_EXPECTATIONS == "rounded": + delta = torch.zeros( + (nheads_q, total_q + 128 * batch), device=q.device, dtype=torch.float32 + ) + else: + delta = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) # Upstream change: base seeding logic on provided rng_state instead of dropout probability. if rng_state is not None: @@ -495,7 +607,7 @@ def varlen_bwd( # call implementation if DEBUG: print(f"Using Triton implementation in {BWD_MODE} mode") - delta = attention_backward_triton_impl( + attention_backward_triton_impl( do=dout, q=q, k=k, @@ -505,6 +617,7 @@ def varlen_bwd( dq=dq, dk=dk, dv=dv, + delta=delta, sm_scale=softmax_scale, alibi_slopes=alibi_slopes, causal=causal, @@ -532,7 +645,11 @@ def varlen_bwd( assert dq.shape == q.shape, f"[varlen_bwd] dq shape {dq.shape} != q shape {q.shape}" assert dk.shape == k.shape, f"[varlen_bwd] dk shape {dk.shape} != k shape {k.shape}" assert dv.shape == v.shape, f"[varlen_bwd] dv shape {dv.shape} != v shape {v.shape}" - expected_delta_shape = (q.shape[1], q.shape[0]) # (Hq, Total_Q) + if SHAPE_EXPECTATIONS == "rounded": + batch = len(cu_seqlens_q) - 1 + expected_delta_shape = (q.shape[1], q.shape[0] + 128 * batch) + else: + expected_delta_shape = (q.shape[1], q.shape[0]) # (Hq, Total_Q) assert ( delta.shape == expected_delta_shape ), f"[varlen_bwd] delta shape {delta.shape} != {expected_delta_shape}" @@ -622,7 +739,10 @@ def fwd_kvcache( v_new = v # get shape - batch, _, nheads_q, _ = q.shape + batch, seqlen_q, nheads_q, _ = q.shape + + # Create softmax_lse tensor - decode always uses exact shape (B, Hq, Sq) + softmax_lse = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32) if alibi_slopes is not None: if alibi_slopes.dim() == 1: @@ -633,13 +753,14 @@ def fwd_kvcache( # launch kernel if DEBUG: print("Using Triton implementation") - softmax_lse = attention_forward_decode_triton_impl( + attention_forward_decode_triton_impl( q, k_cache, v_cache, k_new, v_new, out, + softmax_lse, softmax_scale, causal, window_left, diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py index e7281373aac..207c9ed2e52 100755 --- a/flash_attn/flash_attn_triton_amd/interface_v3.py +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -292,13 +292,18 @@ def fwd( f"Using Decode Triton implementation (cache_seqlens={seqused_k is not None}, k_new={k_new is not None}, v_new={v_new is not None}, kv_batch_idx={kv_batch_idx is not None})" ) - softmax_lse = attention_forward_decode_triton_impl( + # Create softmax_lse tensor for decode - always exact shape (B, Hq, Sq) + batch, seqlen_q, nheads_q, _ = q.shape + softmax_lse = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32) + + attention_forward_decode_triton_impl( q, k, v, k_new, v_new, out, + softmax_lse, softmax_scale, causal_flag, window_size_left, @@ -319,11 +324,27 @@ def fwd( else: if DEBUG: print("Using Prefill Triton implementation") - softmax_lse, _ = attention_forward_prefill_triton_impl( + + # Create softmax_lse tensor - FA3 always uses exact shapes + if layout == "thd": + # varlen: (Hq, Total_Q) + total_q, nheads_q, _ = q.shape + softmax_lse = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) + else: + # bshd: (B, Hq, Sq) + batch, seqlen_q, nheads_q, _ = q.shape + softmax_lse = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32) + + # sd_mask is not returned in v3 interface + sd_mask = None + + attention_forward_prefill_triton_impl( q, k, v, out, + softmax_lse, + sd_mask, softmax_scale, alibi_slopes, causal_flag, @@ -356,6 +377,31 @@ def fwd( print("out:", out.dtype if out is not None else None, out.shape if out is not None else None) print("softmax_lse:", softmax_lse.dtype if softmax_lse is not None else None, softmax_lse.shape if softmax_lse is not None else None) + # --- Assertions (FA3 always expects exact shapes) --- + # out: same shape as q except last dim is v's head_dim + if layout == "thd": + # varlen: (Total_Q, Hq, Dv) + assert out.shape[0] == q.shape[0], f"[fwd_v3] out.shape[0] {out.shape[0]} != q.shape[0] {q.shape[0]}" + assert out.shape[1] == q.shape[1], f"[fwd_v3] out.shape[1] {out.shape[1]} != q.shape[1] {q.shape[1]}" + assert out.shape[2] == v.shape[-1], f"[fwd_v3] out.shape[2] {out.shape[2]} != v.shape[-1] {v.shape[-1]}" + else: + # bshd: (B, Sq, Hq, Dv) + assert out.shape[0] == q.shape[0], f"[fwd_v3] out.shape[0] {out.shape[0]} != q.shape[0] {q.shape[0]}" + assert out.shape[1] == q.shape[1], f"[fwd_v3] out.shape[1] {out.shape[1]} != q.shape[1] {q.shape[1]}" + assert out.shape[2] == q.shape[2], f"[fwd_v3] out.shape[2] {out.shape[2]} != q.shape[2] {q.shape[2]}" + assert out.shape[3] == v.shape[-1], f"[fwd_v3] out.shape[3] {out.shape[3]} != v.shape[-1] {v.shape[-1]}" + + # softmax_lse dtype + assert softmax_lse.dtype == torch.float32, f"[fwd_v3] softmax_lse dtype {softmax_lse.dtype} != torch.float32" + # softmax_lse shape depends on layout + if layout == "thd": + # varlen: (Hq, Total_Q) + expected_lse_shape = (q.shape[1], q.shape[0]) + else: + # bshd: (B, Hq, Sq) + expected_lse_shape = (q.shape[0], q.shape[2], q.shape[1]) + assert softmax_lse.shape == expected_lse_shape, f"[fwd_v3] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" + # Return format compatible with v3 # V3 returns (out, softmax_lse, *rest) where rest can be empty or contain additional outputs return out, softmax_lse @@ -455,13 +501,17 @@ def bwd( # Variable length sequence mode layout = "thd" batch = len(cu_seqlens_q) - 1 - _, nheads_q, _ = q.shape + total_q, nheads_q, _ = q.shape + # Create delta tensor - varlen: (Hq, Total_Q) + delta = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) else: # Regular batch mode layout = "bshd" - batch, _, nheads_q, _ = q.shape + batch, seqlen_q, nheads_q, _ = q.shape max_seqlen_q = q.shape[1] if max_seqlen_q is None else max_seqlen_q max_seqlen_k = k.shape[1] if max_seqlen_k is None else max_seqlen_k + # Create delta tensor - bshd: (B, Hq, Sq) + delta = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32) # V3 backward doesn't have dropout or alibi slopes dropout_p = 0.0 @@ -471,7 +521,7 @@ def bwd( # Call implementation if DEBUG: print(f"Using Triton implementation in {BWD_MODE} mode") - delta = attention_backward_triton_impl( + attention_backward_triton_impl( do=dout, q=q, k=k, @@ -481,6 +531,7 @@ def bwd( dq=dq, dk=dk, dv=dv, + delta=delta, sm_scale=softmax_scale, alibi_slopes=alibi_slopes, causal=causal, @@ -505,6 +556,21 @@ def bwd( print("dv:", dv.dtype if dv is not None else None, dv.shape if dv is not None else None) print("delta:", delta.dtype if delta is not None else None, delta.shape if delta is not None else None) + # --- Assertions (FA3 always expects exact shapes) --- + # Gradients should match input shapes + assert dq.shape == q.shape, f"[bwd_v3] dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"[bwd_v3] dk shape {dk.shape} != k shape {k.shape}" + assert dv.shape == v.shape, f"[bwd_v3] dv shape {dv.shape} != v shape {v.shape}" + # delta (softmax_d) should match softmax_lse shape + assert delta.dtype == torch.float32, f"[bwd_v3] delta dtype {delta.dtype} != torch.float32" + if layout == "thd": + # varlen: (Hq, Total_Q) + expected_delta_shape = (q.shape[1], q.shape[0]) + else: + # bshd: (B, Hq, Sq) + expected_delta_shape = (q.shape[0], q.shape[2], q.shape[1]) + assert delta.shape == expected_delta_shape, f"[bwd_v3] delta shape {delta.shape} != {expected_delta_shape}" + # V3 expects (dq, dk, dv, softmax_d, *rest) # delta is the softmax_d in this case return dq, dk, dv, delta diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 5a42c89684a..7c7a7f2d9f1 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -56,6 +56,7 @@ USE_EXP2 = True PHILOX_SEED = 0x1BF58 PHILOX_OFFSET = 0x1D4B49 +SHAPE_EXPECTATIONS: Literal["exact", "rounded"] = "exact" # ------------------------------- From cc4cbf95ed0d8b3790dbc22a6641b454d860bcdc Mon Sep 17 00:00:00 2001 From: Michael Date: Thu, 9 Oct 2025 23:16:33 -0500 Subject: [PATCH 20/33] tune more --- flash_attn/flash_attn_triton_amd/bwd.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index a6f96e4011b..195d6d47975 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -63,12 +63,16 @@ def get_bwd_configs(autotune: bool): else: preprocess_autotune_configs = [ triton.Config({"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8), + triton.Config({"PRE_BLOCK": 64, "waves_per_eu": 1}, num_stages=1, num_warps=4), ] noncausal_autotune_configs = [ triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + triton.Config({"BLOCK_M1": 64, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 2, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), ] causal_autotune_configs = [ triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), ] else: preprocess_autotune_configs = [ @@ -180,6 +184,7 @@ def get_bwd_configs(autotune: bool): (causal_autotune_configs, causal_autotune_keys), \ (noncausal_autotune_configs, noncausal_autotune_keys) +# os.environ["TRITON_PRINT_AUTOTUNING"] = "1" ( (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), From f5f67c989b8a7c0678ba4398e45763c879b6e46c Mon Sep 17 00:00:00 2001 From: Michael Date: Fri, 10 Oct 2025 09:14:18 -0500 Subject: [PATCH 21/33] improve perf --- flash_attn/flash_attn_triton_amd/bwd.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 195d6d47975..bfdc7120f87 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -54,6 +54,8 @@ def get_bwd_configs(autotune: bool): noncausal_autotune_configs = [ triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), triton.Config({"BLOCK_M1": 64, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 2, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=8), + triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=8), ] causal_autotune_configs = [ triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), From 91abd990db3de620da4d8fa71ecbf9bdad62b828 Mon Sep 17 00:00:00 2001 From: Michael Date: Fri, 10 Oct 2025 09:32:19 -0500 Subject: [PATCH 22/33] clean up --- .../flash_attn_triton_amd/fwd_prefill.py | 6 + .../flash_attn_triton_amd/interface_v2.py | 3 - flash_attn/flash_attn_triton_amd/utils.py | 296 +----------------- 3 files changed, 7 insertions(+), 298 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index dc259e99675..506cfba02f9 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -73,6 +73,12 @@ def get_fwd_configs(autotune: bool): num_stages=1, num_warps=4, )) + elif arch in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"): # RDNA architectures + configs.append(triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=2, + )) else: configs.append(triton.Config( {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, diff --git a/flash_attn/flash_attn_triton_amd/interface_v2.py b/flash_attn/flash_attn_triton_amd/interface_v2.py index 71e1630fc2f..48f0ed45f8b 100644 --- a/flash_attn/flash_attn_triton_amd/interface_v2.py +++ b/flash_attn/flash_attn_triton_amd/interface_v2.py @@ -79,7 +79,6 @@ def fwd( # Create output tensors based on shape expectations if SHAPE_EXPECTATIONS == "rounded": - # Rounded shapes for NVIDIA compatibility softmax_lse = torch.zeros( (batch, nheads_q, round_multiple(max_seqlen_q, 128)), device=q.device, @@ -94,7 +93,6 @@ def fwd( else: sd_mask = None else: - # Exact shapes for AMD softmax_lse = torch.zeros( (batch, nheads_q, max_seqlen_q), device=q.device, @@ -142,7 +140,6 @@ def fwd( None, None, None, - None, ) if DEBUG: diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 7c7a7f2d9f1..1f8fd86d1f7 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -50,7 +50,7 @@ ) if USE_TRITON_ROCM: # TODO remove this random.seed(42) -BWD_MODE = os.environ.get("BWD_MODE", "fused").lower() +BWD_MODE: Literal["fused", "fused_atomic", "split"] = "fused" DROPOUT_USE_PYTORCH = False DROPOUT_DUMP = False USE_EXP2 = True @@ -555,300 +555,6 @@ def generate_varlen_kv_packed( x.requires_grad_() return x, cu_seqlens, max_seqlen - -def input_helper( - BATCH: int, - HQ: int, - HK: int, - N_CTX_Q: int, - N_CTX_K: int, - D_HEAD: int, - CAUSAL: bool, - DROPOUT_P: float, - dtype: torch.dtype, - layout: Literal["bshd", "bhsd", "thd"], - packing: Optional[Literal["kv", "qkv"]] = None, - device: Literal["cpu", "cuda"] = "cuda", -): - torch.manual_seed(20) - is_fp8_dtype = is_dtype_fp8(dtype) - - if layout == "thd": - # set params - TOTAL_SEQLENS_Q = BATCH * N_CTX_Q - TOTAL_SEQLENS_K = BATCH * N_CTX_K - equal_seqlens = False - - # deal with packing - if packing is None: - # gen tensors - if is_fp8_dtype: - q, cu_seqlens_q, max_seqlen_q, descale_q = generate_varlen_tensor( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - k, cu_seqlens_k, max_seqlen_k, descale_k = generate_varlen_tensor( - TOTAL_SEQLENS_K, - HK, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - v, _, _, descale_v = generate_varlen_tensor( - TOTAL_SEQLENS_K, - HK, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - do, _, _, descale_do = generate_varlen_tensor( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - else: - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor( - TOTAL_SEQLENS_K, - HK, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - v, _, _ = generate_varlen_tensor( - TOTAL_SEQLENS_K, - HK, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - do, _, _ = generate_varlen_tensor( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - elif packing == "kv": - # gen tensors with kv packing - if is_fp8_dtype: - raise ValueError("FP8 not supported for KV packing yet") - else: - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - kv, cu_seqlens_k, max_seqlen_k = generate_varlen_kv_packed( - TOTAL_SEQLENS_K, - HK, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - do, _, _ = generate_varlen_tensor( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - elif packing == "qkv": - # qkv packing - requires same sequence length for q and k - assert ( - N_CTX_Q == N_CTX_K - ), "For QKV packing, Q and K must have same sequence length" - assert HQ == HK, "For QKV packing, Q and K must have same number of heads" - - if is_fp8_dtype: - raise ValueError("FP8 not supported for QKV packing yet") - else: - qkv, cu_seqlens_q, max_seqlen_q = generate_varlen_qkv_packed( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - cu_seqlens_k = cu_seqlens_q - max_seqlen_k = max_seqlen_q - do, _, _ = generate_varlen_tensor( - TOTAL_SEQLENS_Q, - HQ, - D_HEAD, - batch_size=BATCH, - dtype=dtype, - device=device, - equal_seqlens=equal_seqlens, - ) - - elif layout == "bshd" or layout == "bhsd": - # deal with packing - if packing is None: - # gen tensors - if layout == "bshd": - if is_fp8_dtype: - q, descale_q = generate_bshd_tensor( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - k, descale_k = generate_bshd_tensor( - BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device - ) - v, descale_v = generate_bshd_tensor( - BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device - ) - do, descale_do = generate_bshd_tensor( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - else: - q = generate_bshd_tensor( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - k = generate_bshd_tensor( - BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device - ) - v = generate_bshd_tensor( - BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device - ) - do = generate_bshd_tensor( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - elif layout == "bhsd": - q, descale_q = generate_bhsd_tensor( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - k, descale_k = generate_bhsd_tensor( - BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device - ) - v, descale_v = generate_bhsd_tensor( - BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device - ) - do, descale_do = generate_bhsd_tensor( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - else: - q = generate_bhsd_tensor( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - k = generate_bhsd_tensor( - BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device - ) - v = generate_bhsd_tensor( - BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device - ) - do = generate_bhsd_tensor( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - elif packing == "kv": - # gen tensors with kv packing - if is_fp8_dtype: - raise ValueError("FP8 not supported for KV packing yet") - else: - if layout == "bshd": - q = generate_bshd_tensor( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - kv = generate_bshd_kv_packed( - BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device - ) - do = generate_bshd_tensor( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - elif layout == "bhsd": - q = generate_bhsd_tensor( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - kv = generate_bhsd_kv_packed( - BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device - ) - do = generate_bhsd_tensor( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - elif packing == "qkv": - # qkv packing - requires same sequence length for q and k - assert ( - N_CTX_Q == N_CTX_K - ), "For QKV packing, Q and K must have same sequence length" - assert HQ == HK, "For QKV packing, Q and K must have same number of heads" - - if is_fp8_dtype: - raise ValueError("FP8 not supported for QKV packing yet") - else: - if layout == "bshd": - qkv = generate_bshd_qkv_packed( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - do = generate_bshd_tensor( - BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device - ) - elif layout == "bhsd": - qkv = generate_bhsd_qkv_packed( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - do = generate_bhsd_tensor( - BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device - ) - - else: - raise ValueError(f"Unknown layout: {layout}") - - # return based on packing - if packing is None: - if is_fp8_dtype: - return (q, descale_q), (k, descale_k), (v, descale_v), (do, descale_do) - else: - return q, k, v, do - elif packing == "kv": - if is_fp8_dtype: - raise ValueError("FP8 not supported kv packing yet") - else: - return q, kv, do - elif packing == "qkv": - if is_fp8_dtype: - raise ValueError("FP8 not supported qkv packing yet") - else: - return qkv, do - else: - assert False, f"Unsupported packing mode: {packing}" - - # ------------------------------- # Alibi # ------------------------------- From f4224ddf1408909f6c37c3dba87860bc809cb38d Mon Sep 17 00:00:00 2001 From: Michael Date: Fri, 10 Oct 2025 13:07:47 -0500 Subject: [PATCH 23/33] lint --- flash_attn/flash_attn_triton_amd/bwd.py | 434 +++++++++++++----- .../flash_attn_triton_amd/fwd_decode.py | 34 +- .../flash_attn_triton_amd/fwd_prefill.py | 160 +++++-- .../flash_attn_triton_amd/interface_v2.py | 33 +- .../flash_attn_triton_amd/interface_v3.py | 189 ++++++-- flash_attn/flash_attn_triton_amd/utils.py | 12 +- 6 files changed, 649 insertions(+), 213 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index bfdc7120f87..2becd57c1c7 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -27,17 +27,28 @@ def get_bwd_configs(autotune: bool): # keys preprocess_autotune_keys = [ "max_seqlen_q", - "ACTUAL_HEAD_DIM", "IS_VARLEN", + "ACTUAL_HEAD_DIM", + "IS_VARLEN", ] - + causal_autotune_keys = [ - "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "dropout_p", + "max_seqlen_q", + "max_seqlen_k", + "ACTUAL_HEAD_DIM", + "IS_VARLEN", + "HQ", + "HK", ] - + noncausal_autotune_keys = [ - "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "dropout_p", + "max_seqlen_q", + "max_seqlen_k", + "ACTUAL_HEAD_DIM", + "IS_VARLEN", + "HQ", + "HK", ] # default config @@ -47,82 +58,264 @@ def get_bwd_configs(autotune: bool): if arch == "gfx942": if get_cu_count() < 304: preprocess_autotune_configs = [ - triton.Config({"PRE_BLOCK": 64, "waves_per_eu": 1}, num_stages=1, num_warps=8), - triton.Config({"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8), - triton.Config({"PRE_BLOCK": 128, "waves_per_eu": 2}, num_stages=1, num_warps=4), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 1}, num_stages=1, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 128, "waves_per_eu": 2}, num_stages=1, num_warps=4 + ), ] noncausal_autotune_configs = [ - triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), - triton.Config({"BLOCK_M1": 64, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), - triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 2, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=8), - triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=8), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=8, + ), ] causal_autotune_configs = [ - triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), - triton.Config({"BLOCK_M1": 64, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), - triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), ] else: preprocess_autotune_configs = [ - triton.Config({"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8), - triton.Config({"PRE_BLOCK": 64, "waves_per_eu": 1}, num_stages=1, num_warps=4), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 1}, num_stages=1, num_warps=4 + ), ] noncausal_autotune_configs = [ - triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), - triton.Config({"BLOCK_M1": 64, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), - triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 2, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), ] causal_autotune_configs = [ - triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), - triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), ] else: preprocess_autotune_configs = [ - triton.Config({"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 + ), ] noncausal_autotune_configs = [ - triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), ] causal_autotune_configs = [ - triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16}, num_stages=1, num_warps=4), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + }, + num_stages=1, + num_warps=4, + ), ] # assert constraints - for (noncausal_cfg, causal_cfg) in zip(noncausal_autotune_configs, causal_autotune_configs): - assert noncausal_cfg.all_kwargs()["BLOCK_N1"] == noncausal_cfg.all_kwargs()["BLOCK_M2"], f"BLOCK_N1 ({noncausal_cfg.all_kwargs()['BLOCK_N1']}) must equal BLOCK_M2 ({noncausal_cfg.all_kwargs()['BLOCK_M2']})" - assert causal_cfg.all_kwargs()["BLOCK_N1"] == causal_cfg.all_kwargs()["BLOCK_M2"], f"BLOCK_N1 ({causal_cfg.all_kwargs()['BLOCK_N1']}) must equal BLOCK_M2 ({causal_cfg.all_kwargs()['BLOCK_M2']})" - - return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) - + for noncausal_cfg, causal_cfg in zip( + noncausal_autotune_configs, causal_autotune_configs + ): + assert ( + noncausal_cfg.all_kwargs()["BLOCK_N1"] + == noncausal_cfg.all_kwargs()["BLOCK_M2"] + ), f"BLOCK_N1 ({noncausal_cfg.all_kwargs()['BLOCK_N1']}) must equal BLOCK_M2 ({noncausal_cfg.all_kwargs()['BLOCK_M2']})" + assert ( + causal_cfg.all_kwargs()["BLOCK_N1"] + == causal_cfg.all_kwargs()["BLOCK_M2"] + ), f"BLOCK_N1 ({causal_cfg.all_kwargs()['BLOCK_N1']}) must equal BLOCK_M2 ({causal_cfg.all_kwargs()['BLOCK_M2']})" + + return ( + (preprocess_autotune_configs, preprocess_autotune_keys), + (causal_autotune_configs, causal_autotune_keys), + (noncausal_autotune_configs, noncausal_autotune_keys), + ) # param options - PRE_BLOCK_OPTIONS = [64, 128] # og: 128 - PRE_WAVES_PER_EU_OPTIONS=[1, 2] - PRE_NUM_STAGES_OPTIONS=[1, 2] - PRE_NUM_WARPS_OPTIONS=[4, 8] - NUM_STAGES_OPTIONS = [1, 2] # og: 1 - NUM_WARPS_OPTIONS = [4, 8] # og: 4 - WAVES_PER_EU_OPTIONS = [1, 2] # og: 1 - MATRIX_INSTR_NONKDIM_OPTIONS = [16, 32] # og: 16 - CAUSAL_BLOCK_M1_OPTIONS = [ # og: 32 - 32, 64, - ] - CAUSAL_BLOCK_N1_M2_OPTIONS = [ # og: 128 - 64, 128, 256 + PRE_BLOCK_OPTIONS = [64, 128] # og: 128 + PRE_WAVES_PER_EU_OPTIONS = [1, 2] + PRE_NUM_STAGES_OPTIONS = [1, 2] + PRE_NUM_WARPS_OPTIONS = [4, 8] + NUM_STAGES_OPTIONS = [1, 2] # og: 1 + NUM_WARPS_OPTIONS = [4, 8] # og: 4 + WAVES_PER_EU_OPTIONS = [1, 2] # og: 1 + MATRIX_INSTR_NONKDIM_OPTIONS = [16, 32] # og: 16 + CAUSAL_BLOCK_M1_OPTIONS = [ # og: 32 + 32, + 64, ] - CAUSAL_BLOCK_N2_OPTIONS = [ # og: 32 - 32, 64 - ] - NON_CAUSAL_BLOCK_M1_OPTIONS = [ # og: 32 - 32, 64 - ] - NON_CAUSAL_BLOCK_N1_M2_OPTIONS = [ # og: 128 - 64, 128, 256 - ] - NON_CAUSAL_BLOCK_N2_OPTIONS = [ # og: 32 - 32, 64 - ] - BLK_SLICE_FACTOR_OPTIONS = [2] # og: 2 + CAUSAL_BLOCK_N1_M2_OPTIONS = [64, 128, 256] # og: 128 + CAUSAL_BLOCK_N2_OPTIONS = [32, 64] # og: 32 + NON_CAUSAL_BLOCK_M1_OPTIONS = [32, 64] # og: 32 + NON_CAUSAL_BLOCK_N1_M2_OPTIONS = [64, 128, 256] # og: 128 + NON_CAUSAL_BLOCK_N2_OPTIONS = [32, 64] # og: 32 + BLK_SLICE_FACTOR_OPTIONS = [2] # og: 2 # ==================== sweep configs ================================ preprocess_autotune_configs = [] @@ -131,10 +324,14 @@ def get_bwd_configs(autotune: bool): for pre_waves in PRE_WAVES_PER_EU_OPTIONS: for pre_block in PRE_BLOCK_OPTIONS: preprocess_autotune_configs.append( - triton.Config({ - "PRE_BLOCK": pre_block, - "waves_per_eu": pre_waves, - }, num_stages=pre_num_stages, num_warps=pre_num_warps) + triton.Config( + { + "PRE_BLOCK": pre_block, + "waves_per_eu": pre_waves, + }, + num_stages=pre_num_stages, + num_warps=pre_num_warps, + ) ) causal_autotune_configs = [] @@ -147,20 +344,28 @@ def get_bwd_configs(autotune: bool): m2 = n1 for n2 in CAUSAL_BLOCK_N2_OPTIONS: # Ensure constraint - assert n1 == m2, f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" - + assert ( + n1 == m2 + ), f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" + for blk_slice in BLK_SLICE_FACTOR_OPTIONS: causal_autotune_configs.append( - triton.Config({ - "BLOCK_M1": m1, "BLOCK_N1": n1, - "BLOCK_M2": m2, "BLOCK_N2": n2, - "BLK_SLICE_FACTOR": blk_slice, - "waves_per_eu": waves, - "matrix_instr_nonkdim": matrix_instr_nonkdim - }, num_stages=num_stages, num_warps=num_warps) + triton.Config( + { + "BLOCK_M1": m1, + "BLOCK_N1": n1, + "BLOCK_M2": m2, + "BLOCK_N2": n2, + "BLK_SLICE_FACTOR": blk_slice, + "waves_per_eu": waves, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + }, + num_stages=num_stages, + num_warps=num_warps, + ) ) - noncausal_autotune_configs = [] + noncausal_autotune_configs = [] for num_warps in NUM_WARPS_OPTIONS: for num_stages in NUM_STAGES_OPTIONS: for waves in WAVES_PER_EU_OPTIONS: @@ -170,21 +375,32 @@ def get_bwd_configs(autotune: bool): m2 = n1 for n2 in NON_CAUSAL_BLOCK_N2_OPTIONS: # Ensure constraint - assert n1 == m2, f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" + assert ( + n1 == m2 + ), f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" for blk_slice in BLK_SLICE_FACTOR_OPTIONS: noncausal_autotune_configs.append( - triton.Config({ - "BLOCK_M1": m1, "BLOCK_N1": n1, - "BLOCK_M2": m2, "BLOCK_N2": n2, - "BLK_SLICE_FACTOR": blk_slice, - "waves_per_eu": waves, - "matrix_instr_nonkdim": matrix_instr_nonkdim - }, num_stages=num_stages, num_warps=num_warps) + triton.Config( + { + "BLOCK_M1": m1, + "BLOCK_N1": n1, + "BLOCK_M2": m2, + "BLOCK_N2": n2, + "BLK_SLICE_FACTOR": blk_slice, + "waves_per_eu": waves, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + }, + num_stages=num_stages, + num_warps=num_warps, + ) ) - return (preprocess_autotune_configs, preprocess_autotune_keys), \ - (causal_autotune_configs, causal_autotune_keys), \ - (noncausal_autotune_configs, noncausal_autotune_keys) + return ( + (preprocess_autotune_configs, preprocess_autotune_keys), + (causal_autotune_configs, causal_autotune_keys), + (noncausal_autotune_configs, noncausal_autotune_keys), + ) + # os.environ["TRITON_PRINT_AUTOTUNING"] = "1" ( @@ -659,9 +875,7 @@ def _bwd_dkdvdq_inner_atomic( # NOTE: Possible problems with the atomic add: contention, is inside a loop which has achieved bad perf before # (BLOCK_M, BLOCK_N) x (BLOCK_N, D) if IS_FP8: - dq_partial = ( - tl.dot(dsT.to(k.type.element_ty).T, k) * descale_k - ) + dq_partial = tl.dot(dsT.to(k.type.element_ty).T, k) * descale_k else: dq_partial = tl.dot(dsT.to(k.type.element_ty).T, k) tl.atomic_add( @@ -3618,9 +3832,7 @@ def attention_backward_triton_impl( assert ( q.device == k.device == v.device == o.device == do.device == softmax_lse.device ), f"All tensors must be on the same device. Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}, do={do.device}, softmax_lse={softmax_lse.device}" - assert ( - q.dtype == k.dtype == v.dtype - ), "q, k, v must have the same dtype" + assert q.dtype == k.dtype == v.dtype, "q, k, v must have the same dtype" current_device = torch.cuda.current_device() assert ( q.is_cuda and q.device.index == current_device @@ -3826,18 +4038,12 @@ def attention_backward_triton_impl( ) # For GQA/MQA, q_descale should be shaped (batch, nheads_k) to match forward pass - descale_q = torch.ones( - batch, nheads_k, dtype=torch.float32, device=q.device - ) - - descale_k = torch.ones( - batch, nheads_k, dtype=torch.float32, device=q.device - ) - - descale_v = torch.ones( - batch, nheads_k, dtype=torch.float32, device=q.device - ) - + descale_q = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) + + descale_k = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) + + descale_v = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) + stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None @@ -3868,8 +4074,12 @@ def attention_backward_triton_impl( if IS_VARLEN: # Shape expected by interface varlen backward: (Hq, Total_Q) total_q, _, _ = q.shape - assert delta.shape[0] == nheads_q, f"delta.shape[0] ({delta.shape[0]}) must equal nheads_q ({nheads_q})" - assert delta.shape[1] >= total_q, f"delta.shape[1] ({delta.shape[1]}) must be >= total_q ({total_q})" + assert ( + delta.shape[0] == nheads_q + ), f"delta.shape[0] ({delta.shape[0]}) must equal nheads_q ({nheads_q})" + assert ( + delta.shape[1] >= total_q + ), f"delta.shape[1] ({delta.shape[1]}) must be >= total_q ({total_q})" assert delta.dtype == torch.float32, f"delta must be float32, got {delta.dtype}" assert delta.device == q.device, f"delta must be on same device as q" stride_delta_b, stride_delta_h, stride_delta_m = ( @@ -3880,9 +4090,15 @@ def attention_backward_triton_impl( else: # Shape expected by dense backward: (B, Hq, Sq) seqlen_q = q.shape[1] - assert delta.shape[0] == batch, f"delta.shape[0] ({delta.shape[0]}) must equal batch ({batch})" - assert delta.shape[1] == nheads_q, f"delta.shape[1] ({delta.shape[1]}) must equal nheads_q ({nheads_q})" - assert delta.shape[2] >= seqlen_q, f"delta.shape[2] ({delta.shape[2]}) must be >= seqlen_q ({seqlen_q})" + assert ( + delta.shape[0] == batch + ), f"delta.shape[0] ({delta.shape[0]}) must equal batch ({batch})" + assert ( + delta.shape[1] == nheads_q + ), f"delta.shape[1] ({delta.shape[1]}) must equal nheads_q ({nheads_q})" + assert ( + delta.shape[2] >= seqlen_q + ), f"delta.shape[2] ({delta.shape[2]}) must be >= seqlen_q ({seqlen_q})" assert delta.dtype == torch.float32, f"delta must be float32, got {delta.dtype}" assert delta.device == q.device, f"delta must be on same device as q" stride_delta_b, stride_delta_h, stride_delta_m = delta.stride() @@ -4137,7 +4353,7 @@ def attention_backward_triton_impl( grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) - + # fuses dk, dv, dq computations into one kernel by computing the dq using atomic adds between workgroups BLOCK_N = ( 128 if BLOCK_D_MODEL_POW2 < 160 else 64 @@ -4293,7 +4509,7 @@ def attention_backward_triton_impl( grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) - + if causal: _bwd_kernel_split_dkdv_causal[grid_dkdv]( q, diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index fb096de96db..4645dcc97fe 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -1030,13 +1030,13 @@ def attention_forward_decode_triton_impl( stride_kn_h, stride_kn_n, stride_kn_d, - ) = (None, None, None, None), (None, None, None, None) + ) = (None, None, None, None,), (None, None, None, None) (_, seqlen_vn, nheads_vn, dim_vn), ( stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d, - ) = (None, None, None, None), (None, None, None, None) + ) = (None, None, None, None,), (None, None, None, None) (_, seqlen_o, nheads_o, dim_o), (stride_oz, stride_oh, stride_om, stride_od) = ( get_shape_and_strides_from_layout(out, layout) ) @@ -1105,17 +1105,25 @@ def attention_forward_decode_triton_impl( dtype=torch.float32, device=q.device, ) - + # Validate pre-allocated softmax_lse tensor # Expected shape after view: (batch_size, n_group_q * heads_per_group_q, seqlen_q) # Internal shape: (batch_size * n_group_q * heads_per_group_q, seqlen_q) expected_h_total = batch_size * n_group_q * heads_per_group_q - assert softmax_lse.shape[0] == batch_size, f"softmax_lse.shape[0] ({softmax_lse.shape[0]}) must equal batch_size ({batch_size})" - assert softmax_lse.shape[1] == n_group_q * heads_per_group_q, f"softmax_lse.shape[1] ({softmax_lse.shape[1]}) must equal n_group_q * heads_per_group_q ({n_group_q * heads_per_group_q})" - assert softmax_lse.shape[2] >= seqlen_q, f"softmax_lse.shape[2] ({softmax_lse.shape[2]}) must be >= seqlen_q ({seqlen_q})" - assert softmax_lse.dtype == torch.float32, f"softmax_lse must be float32, got {softmax_lse.dtype}" + assert ( + softmax_lse.shape[0] == batch_size + ), f"softmax_lse.shape[0] ({softmax_lse.shape[0]}) must equal batch_size ({batch_size})" + assert ( + softmax_lse.shape[1] == n_group_q * heads_per_group_q + ), f"softmax_lse.shape[1] ({softmax_lse.shape[1]}) must equal n_group_q * heads_per_group_q ({n_group_q * heads_per_group_q})" + assert ( + softmax_lse.shape[2] >= seqlen_q + ), f"softmax_lse.shape[2] ({softmax_lse.shape[2]}) must be >= seqlen_q ({seqlen_q})" + assert ( + softmax_lse.dtype == torch.float32 + ), f"softmax_lse must be float32, got {softmax_lse.dtype}" assert softmax_lse.device == q.device, f"softmax_lse must be on same device as q" - + # Create internal lse view for kernel use lse = softmax_lse.view(expected_h_total, -1)[:, :seqlen_q].contiguous() @@ -1134,11 +1142,15 @@ def attention_forward_decode_triton_impl( IS_FP8 = is_fp8([q, k_cache, v_cache]) if IS_FP8: rec_dtype = get_recommended_fp8_dtype(q) - if q.dtype != rec_dtype or k_cache.dtype != rec_dtype or v_cache.dtype != rec_dtype: + if ( + q.dtype != rec_dtype + or k_cache.dtype != rec_dtype + or v_cache.dtype != rec_dtype + ): arch = get_arch() warnings.warn( - f"Use {rec_dtype} data type on {arch}. Got q: {q.dtype}, k: {k_cache.dtype}, v: {v_cache.dtype}", - UserWarning, + f"Use {rec_dtype} data type on {arch}. Got q: {q.dtype}, k: {k_cache.dtype}, v: {v_cache.dtype}", + UserWarning, ) if (q_descale is None) or (k_descale is None) or (v_descale is None): warnings.warn( diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 506cfba02f9..b976b0a1cac 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -44,56 +44,101 @@ def get_fwd_configs(autotune: bool): if not autotune: arch = get_arch() if arch == "gfx950": - configs.append(triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - )) + configs.append( + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ) + ) elif arch == "gfx942": if get_cu_count() < 304: configs.extend( [ # best fp8 config triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, num_stages=1, num_warps=4, ), # best f16 config triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, + { + "BLOCK_M": 128, + "BLOCK_N": 32, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, num_stages=2, num_warps=4, - ) + ), ] ) else: - configs.append(triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + configs.append( + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ) + ) + elif arch in ( + "gfx1030", + "gfx1100", + "gfx1101", + "gfx1102", + "gfx1200", + "gfx1201", + ): # RDNA architectures + configs.append( + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 32, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, num_stages=1, - num_warps=4, - )) - elif arch in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"): # RDNA architectures - configs.append(triton.Config( - {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=2, - )) + num_warps=2, + ) + ) else: - configs.append(triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - )) + configs.append( + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ) + ) return configs, keys # ===================== Autotune Sweep ===================== BLOCK_M_OPTIONS = [128, 64, 32] BLOCK_N_OPTIONS = [128, 64, 32] - NUM_WARPS_OPTIONS = [2, 4, 8] - NUM_STAGES_OPTIONS = [1, 2] - WAVES_PER_EU_OPTIONS = [4, 2, 1] + NUM_WARPS_OPTIONS = [2, 4, 8] + NUM_STAGES_OPTIONS = [1, 2] + WAVES_PER_EU_OPTIONS = [4, 2, 1] PRE_LOAD_V_OPTIONS = [False] for bm in BLOCK_M_OPTIONS: for bn in BLOCK_N_OPTIONS: @@ -116,6 +161,7 @@ def get_fwd_configs(autotune: bool): return configs, keys + fwd_prefill_autotune_configs, fwd_prefill_autotune_keys = get_fwd_configs(AUTOTUNE) @@ -1611,10 +1657,18 @@ def attention_forward_prefill_triton_impl( head_size_qk = head_size_q # Assert softmax_lse tensor is large enough - assert softmax_lse.shape[0] >= nheads_q, f"softmax_lse.shape[0]={softmax_lse.shape[0]} must be >= nheads_q={nheads_q}" - assert softmax_lse.shape[1] >= total_seqlen_q, f"softmax_lse.shape[1]={softmax_lse.shape[1]} must be >= total_seqlen_q={total_seqlen_q}" - assert softmax_lse.dtype == torch.float32, f"softmax_lse must be float32, got {softmax_lse.dtype}" - assert softmax_lse.device == q.device, f"softmax_lse must be on same device as q" + assert ( + softmax_lse.shape[0] >= nheads_q + ), f"softmax_lse.shape[0]={softmax_lse.shape[0]} must be >= nheads_q={nheads_q}" + assert ( + softmax_lse.shape[1] >= total_seqlen_q + ), f"softmax_lse.shape[1]={softmax_lse.shape[1]} must be >= total_seqlen_q={total_seqlen_q}" + assert ( + softmax_lse.dtype == torch.float32 + ), f"softmax_lse must be float32, got {softmax_lse.dtype}" + assert ( + softmax_lse.device == q.device + ), f"softmax_lse must be on same device as q" # strides stride_qb, stride_qh, stride_qm, stride_qd = ( @@ -1688,11 +1742,21 @@ def attention_forward_prefill_triton_impl( max_seqlens_k = seqlen_k # Assert softmax_lse tensor is large enough - assert softmax_lse.shape[0] >= batch, f"softmax_lse.shape[0]={softmax_lse.shape[0]} must be >= batch={batch}" - assert softmax_lse.shape[1] >= nheads_q, f"softmax_lse.shape[1]={softmax_lse.shape[1]} must be >= nheads_q={nheads_q}" - assert softmax_lse.shape[2] >= seqlen_q, f"softmax_lse.shape[2]={softmax_lse.shape[2]} must be >= seqlen_q={seqlen_q}" - assert softmax_lse.dtype == torch.float32, f"softmax_lse must be float32, got {softmax_lse.dtype}" - assert softmax_lse.device == q.device, f"softmax_lse must be on same device as q" + assert ( + softmax_lse.shape[0] >= batch + ), f"softmax_lse.shape[0]={softmax_lse.shape[0]} must be >= batch={batch}" + assert ( + softmax_lse.shape[1] >= nheads_q + ), f"softmax_lse.shape[1]={softmax_lse.shape[1]} must be >= nheads_q={nheads_q}" + assert ( + softmax_lse.shape[2] >= seqlen_q + ), f"softmax_lse.shape[2]={softmax_lse.shape[2]} must be >= seqlen_q={seqlen_q}" + assert ( + softmax_lse.dtype == torch.float32 + ), f"softmax_lse must be float32, got {softmax_lse.dtype}" + assert ( + softmax_lse.device == q.device + ), f"softmax_lse must be on same device as q" # strides stride_qb, stride_qh, stride_qm, stride_qd = ( @@ -1834,15 +1898,27 @@ def attention_forward_prefill_triton_impl( # only. This return holds no useful output aside from debugging. NEEDS_SDMASK = (dropout_p > 0.0) or return_softmax if NEEDS_SDMASK: - assert sd_mask is not None, "sd_mask must be provided when return_softmax=True or dropout_p > 0" + assert ( + sd_mask is not None + ), "sd_mask must be provided when return_softmax=True or dropout_p > 0" # Assert sd_mask tensor is large enough - assert sd_mask.shape[0] >= batch, f"sd_mask.shape[0]={sd_mask.shape[0]} must be >= batch={batch}" - assert sd_mask.shape[1] >= nheads_q, f"sd_mask.shape[1]={sd_mask.shape[1]} must be >= nheads_q={nheads_q}" - assert sd_mask.shape[2] >= max_seqlens_q, f"sd_mask.shape[2]={sd_mask.shape[2]} must be >= max_seqlens_q={max_seqlens_q}" - assert sd_mask.shape[3] >= max_seqlens_k, f"sd_mask.shape[3]={sd_mask.shape[3]} must be >= max_seqlens_k={max_seqlens_k}" - assert sd_mask.dtype == torch.float32, f"sd_mask must be float32, got {sd_mask.dtype}" + assert ( + sd_mask.shape[0] >= batch + ), f"sd_mask.shape[0]={sd_mask.shape[0]} must be >= batch={batch}" + assert ( + sd_mask.shape[1] >= nheads_q + ), f"sd_mask.shape[1]={sd_mask.shape[1]} must be >= nheads_q={nheads_q}" + assert ( + sd_mask.shape[2] >= max_seqlens_q + ), f"sd_mask.shape[2]={sd_mask.shape[2]} must be >= max_seqlens_q={max_seqlens_q}" + assert ( + sd_mask.shape[3] >= max_seqlens_k + ), f"sd_mask.shape[3]={sd_mask.shape[3]} must be >= max_seqlens_k={max_seqlens_k}" + assert ( + sd_mask.dtype == torch.float32 + ), f"sd_mask must be float32, got {sd_mask.dtype}" assert sd_mask.device == q.device, f"sd_mask must be on same device as q" - + if DROPOUT_USE_PYTORCH: dropout_mask = create_dropout_mask( dropout_p, diff --git a/flash_attn/flash_attn_triton_amd/interface_v2.py b/flash_attn/flash_attn_triton_amd/interface_v2.py index 48f0ed45f8b..dfa3c4fdae6 100644 --- a/flash_attn/flash_attn_triton_amd/interface_v2.py +++ b/flash_attn/flash_attn_triton_amd/interface_v2.py @@ -4,7 +4,15 @@ from .fwd_prefill import attention_forward_prefill_triton_impl from .fwd_decode import attention_forward_decode_triton_impl from .bwd import attention_backward_triton_impl -from .utils import DEBUG, USE_EXP2, BWD_MODE, PHILOX_SEED, PHILOX_OFFSET, SHAPE_EXPECTATIONS, round_multiple +from .utils import ( + DEBUG, + USE_EXP2, + BWD_MODE, + PHILOX_SEED, + PHILOX_OFFSET, + SHAPE_EXPECTATIONS, + round_multiple, +) def fwd( @@ -86,7 +94,12 @@ def fwd( ) if return_softmax: sd_mask = torch.zeros( - (batch, nheads_q, round_multiple(max_seqlen_q, 128), round_multiple(max_seqlen_k, 128)), + ( + batch, + nheads_q, + round_multiple(max_seqlen_q, 128), + round_multiple(max_seqlen_k, 128), + ), device=q.device, dtype=torch.float32, ) @@ -254,7 +267,9 @@ def bwd( dtype=torch.float32, ) else: - delta = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32) + delta = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) # Upstream change: base seeding logic on provided rng_state instead of dropout probability. if rng_state is not None: @@ -490,8 +505,12 @@ def varlen_fwd( ), "[varlen_fwd] return_softmax=True but sd_mask is None" assert sd_mask.dim() == 4, f"[varlen_fwd] sd_mask dim {sd_mask.dim()} != 4" batch = len(cu_seqlens_q) - 1 - assert sd_mask.shape[0] == batch, f"[varlen_fwd] sd_mask batch {sd_mask.shape[0]} != {batch}" - assert sd_mask.shape[1] == q.shape[1], f"[varlen_fwd] sd_mask nheads {sd_mask.shape[1]} != {q.shape[1]}" + assert ( + sd_mask.shape[0] == batch + ), f"[varlen_fwd] sd_mask batch {sd_mask.shape[0]} != {batch}" + assert ( + sd_mask.shape[1] == q.shape[1] + ), f"[varlen_fwd] sd_mask nheads {sd_mask.shape[1]} != {q.shape[1]}" if SHAPE_EXPECTATIONS == "rounded": expected_sq = round_multiple(max_seqlen_q, 128) expected_sk = round_multiple(max_seqlen_k, 128) @@ -739,7 +758,9 @@ def fwd_kvcache( batch, seqlen_q, nheads_q, _ = q.shape # Create softmax_lse tensor - decode always uses exact shape (B, Hq, Sq) - softmax_lse = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32) + softmax_lse = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) if alibi_slopes is not None: if alibi_slopes.dim() == 1: diff --git a/flash_attn/flash_attn_triton_amd/interface_v3.py b/flash_attn/flash_attn_triton_amd/interface_v3.py index 207c9ed2e52..3ed35de5cd1 100755 --- a/flash_attn/flash_attn_triton_amd/interface_v3.py +++ b/flash_attn/flash_attn_triton_amd/interface_v3.py @@ -64,10 +64,26 @@ def fwd( print("q:", q.dtype if q is not None else None, q.shape) print("k:", k.dtype if k is not None else None, k.shape) print("v:", v.dtype if v is not None else None, v.shape) - print("k_new:", k_new.dtype if k_new is not None else None, k_new.shape if k_new is not None else None) - print("v_new:", v_new.dtype if v_new is not None else None, v_new.shape if v_new is not None else None) - print("qv:", qv.dtype if qv is not None else None, qv.shape if qv is not None else None) - print("out:", out.dtype if out is not None else None, out.shape if out is not None else None) + print( + "k_new:", + k_new.dtype if k_new is not None else None, + k_new.shape if k_new is not None else None, + ) + print( + "v_new:", + v_new.dtype if v_new is not None else None, + v_new.shape if v_new is not None else None, + ) + print( + "qv:", + qv.dtype if qv is not None else None, + qv.shape if qv is not None else None, + ) + print( + "out:", + out.dtype if out is not None else None, + out.shape if out is not None else None, + ) print( "cu_seqlens_q:", cu_seqlens_q, @@ -120,13 +136,19 @@ def fwd( seqlens_rotary.shape if seqlens_rotary is not None else None, ) print( - "q_descale:", q_descale.dtype if q_descale is not None else None, q_descale.shape if q_descale is not None else None + "q_descale:", + q_descale.dtype if q_descale is not None else None, + q_descale.shape if q_descale is not None else None, ) print( - "k_descale:", k_descale.dtype if k_descale is not None else None, k_descale.shape if k_descale is not None else None + "k_descale:", + k_descale.dtype if k_descale is not None else None, + k_descale.shape if k_descale is not None else None, ) print( - "v_descale:", v_descale.dtype if v_descale is not None else None, v_descale.shape if v_descale is not None else None + "v_descale:", + v_descale.dtype if v_descale is not None else None, + v_descale.shape if v_descale is not None else None, ) print("softmax_scale:", softmax_scale) print("causal:", causal) @@ -269,7 +291,6 @@ def fwd( else: out = out.zero_() - # Handle causal mask causal_flag = bool(causal) @@ -294,7 +315,9 @@ def fwd( # Create softmax_lse tensor for decode - always exact shape (B, Hq, Sq) batch, seqlen_q, nheads_q, _ = q.shape - softmax_lse = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32) + softmax_lse = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) attention_forward_decode_triton_impl( q, @@ -324,20 +347,24 @@ def fwd( else: if DEBUG: print("Using Prefill Triton implementation") - + # Create softmax_lse tensor - FA3 always uses exact shapes if layout == "thd": # varlen: (Hq, Total_Q) total_q, nheads_q, _ = q.shape - softmax_lse = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) + softmax_lse = torch.zeros( + (nheads_q, total_q), device=q.device, dtype=torch.float32 + ) else: # bshd: (B, Hq, Sq) batch, seqlen_q, nheads_q, _ = q.shape - softmax_lse = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32) - + softmax_lse = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) + # sd_mask is not returned in v3 interface sd_mask = None - + attention_forward_prefill_triton_impl( q, k, @@ -374,25 +401,49 @@ def fwd( if DEBUG: print("interface_fa_v3.py::fwd outputs") - print("out:", out.dtype if out is not None else None, out.shape if out is not None else None) - print("softmax_lse:", softmax_lse.dtype if softmax_lse is not None else None, softmax_lse.shape if softmax_lse is not None else None) + print( + "out:", + out.dtype if out is not None else None, + out.shape if out is not None else None, + ) + print( + "softmax_lse:", + softmax_lse.dtype if softmax_lse is not None else None, + softmax_lse.shape if softmax_lse is not None else None, + ) # --- Assertions (FA3 always expects exact shapes) --- # out: same shape as q except last dim is v's head_dim if layout == "thd": # varlen: (Total_Q, Hq, Dv) - assert out.shape[0] == q.shape[0], f"[fwd_v3] out.shape[0] {out.shape[0]} != q.shape[0] {q.shape[0]}" - assert out.shape[1] == q.shape[1], f"[fwd_v3] out.shape[1] {out.shape[1]} != q.shape[1] {q.shape[1]}" - assert out.shape[2] == v.shape[-1], f"[fwd_v3] out.shape[2] {out.shape[2]} != v.shape[-1] {v.shape[-1]}" + assert ( + out.shape[0] == q.shape[0] + ), f"[fwd_v3] out.shape[0] {out.shape[0]} != q.shape[0] {q.shape[0]}" + assert ( + out.shape[1] == q.shape[1] + ), f"[fwd_v3] out.shape[1] {out.shape[1]} != q.shape[1] {q.shape[1]}" + assert ( + out.shape[2] == v.shape[-1] + ), f"[fwd_v3] out.shape[2] {out.shape[2]} != v.shape[-1] {v.shape[-1]}" else: # bshd: (B, Sq, Hq, Dv) - assert out.shape[0] == q.shape[0], f"[fwd_v3] out.shape[0] {out.shape[0]} != q.shape[0] {q.shape[0]}" - assert out.shape[1] == q.shape[1], f"[fwd_v3] out.shape[1] {out.shape[1]} != q.shape[1] {q.shape[1]}" - assert out.shape[2] == q.shape[2], f"[fwd_v3] out.shape[2] {out.shape[2]} != q.shape[2] {q.shape[2]}" - assert out.shape[3] == v.shape[-1], f"[fwd_v3] out.shape[3] {out.shape[3]} != v.shape[-1] {v.shape[-1]}" - + assert ( + out.shape[0] == q.shape[0] + ), f"[fwd_v3] out.shape[0] {out.shape[0]} != q.shape[0] {q.shape[0]}" + assert ( + out.shape[1] == q.shape[1] + ), f"[fwd_v3] out.shape[1] {out.shape[1]} != q.shape[1] {q.shape[1]}" + assert ( + out.shape[2] == q.shape[2] + ), f"[fwd_v3] out.shape[2] {out.shape[2]} != q.shape[2] {q.shape[2]}" + assert ( + out.shape[3] == v.shape[-1] + ), f"[fwd_v3] out.shape[3] {out.shape[3]} != v.shape[-1] {v.shape[-1]}" + # softmax_lse dtype - assert softmax_lse.dtype == torch.float32, f"[fwd_v3] softmax_lse dtype {softmax_lse.dtype} != torch.float32" + assert ( + softmax_lse.dtype == torch.float32 + ), f"[fwd_v3] softmax_lse dtype {softmax_lse.dtype} != torch.float32" # softmax_lse shape depends on layout if layout == "thd": # varlen: (Hq, Total_Q) @@ -400,7 +451,9 @@ def fwd( else: # bshd: (B, Hq, Sq) expected_lse_shape = (q.shape[0], q.shape[2], q.shape[1]) - assert softmax_lse.shape == expected_lse_shape, f"[fwd_v3] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" + assert ( + softmax_lse.shape == expected_lse_shape + ), f"[fwd_v3] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}" # Return format compatible with v3 # V3 returns (out, softmax_lse, *rest) where rest can be empty or contain additional outputs @@ -440,15 +493,45 @@ def bwd( if DEBUG: print() print("interface_fa_v3.py::bwd inputs") - print("dout:", dout.dtype if dout is not None else None, dout.shape if dout is not None else None) - print("q:", q.dtype if q is not None else None, q.shape if q is not None else None) - print("k:", k.dtype if k is not None else None, k.shape if k is not None else None) - print("v:", v.dtype if v is not None else None, v.shape if v is not None else None) - print("out:", out.dtype if out is not None else None, out.shape if out is not None else None) - print("softmax_lse:", softmax_lse.dtype if softmax_lse is not None else None, softmax_lse.shape if softmax_lse is not None else None) - print("dq:", dq.dtype if dq is not None else None, dq.shape if dq is not None else None) - print("dk:", dk.dtype if dk is not None else None, dk.shape if dk is not None else None) - print("dv:", dv.dtype if dv is not None else None, dv.shape if dv is not None else None) + print( + "dout:", + dout.dtype if dout is not None else None, + dout.shape if dout is not None else None, + ) + print( + "q:", q.dtype if q is not None else None, q.shape if q is not None else None + ) + print( + "k:", k.dtype if k is not None else None, k.shape if k is not None else None + ) + print( + "v:", v.dtype if v is not None else None, v.shape if v is not None else None + ) + print( + "out:", + out.dtype if out is not None else None, + out.shape if out is not None else None, + ) + print( + "softmax_lse:", + softmax_lse.dtype if softmax_lse is not None else None, + softmax_lse.shape if softmax_lse is not None else None, + ) + print( + "dq:", + dq.dtype if dq is not None else None, + dq.shape if dq is not None else None, + ) + print( + "dk:", + dk.dtype if dk is not None else None, + dk.shape if dk is not None else None, + ) + print( + "dv:", + dv.dtype if dv is not None else None, + dv.shape if dv is not None else None, + ) print( "cu_seqlens_q:", cu_seqlens_q, @@ -511,7 +594,9 @@ def bwd( max_seqlen_q = q.shape[1] if max_seqlen_q is None else max_seqlen_q max_seqlen_k = k.shape[1] if max_seqlen_k is None else max_seqlen_k # Create delta tensor - bshd: (B, Hq, Sq) - delta = torch.zeros((batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32) + delta = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) # V3 backward doesn't have dropout or alibi slopes dropout_p = 0.0 @@ -551,10 +636,26 @@ def bwd( if DEBUG: print("interface_fa_v3.py::bwd outputs") - print("dq:", dq.dtype if dq is not None else None, dq.shape if dq is not None else None) - print("dk:", dk.dtype if dk is not None else None, dk.shape if dk is not None else None) - print("dv:", dv.dtype if dv is not None else None, dv.shape if dv is not None else None) - print("delta:", delta.dtype if delta is not None else None, delta.shape if delta is not None else None) + print( + "dq:", + dq.dtype if dq is not None else None, + dq.shape if dq is not None else None, + ) + print( + "dk:", + dk.dtype if dk is not None else None, + dk.shape if dk is not None else None, + ) + print( + "dv:", + dv.dtype if dv is not None else None, + dv.shape if dv is not None else None, + ) + print( + "delta:", + delta.dtype if delta is not None else None, + delta.shape if delta is not None else None, + ) # --- Assertions (FA3 always expects exact shapes) --- # Gradients should match input shapes @@ -562,14 +663,18 @@ def bwd( assert dk.shape == k.shape, f"[bwd_v3] dk shape {dk.shape} != k shape {k.shape}" assert dv.shape == v.shape, f"[bwd_v3] dv shape {dv.shape} != v shape {v.shape}" # delta (softmax_d) should match softmax_lse shape - assert delta.dtype == torch.float32, f"[bwd_v3] delta dtype {delta.dtype} != torch.float32" + assert ( + delta.dtype == torch.float32 + ), f"[bwd_v3] delta dtype {delta.dtype} != torch.float32" if layout == "thd": # varlen: (Hq, Total_Q) expected_delta_shape = (q.shape[1], q.shape[0]) else: # bshd: (B, Hq, Sq) expected_delta_shape = (q.shape[0], q.shape[2], q.shape[1]) - assert delta.shape == expected_delta_shape, f"[bwd_v3] delta shape {delta.shape} != {expected_delta_shape}" + assert ( + delta.shape == expected_delta_shape + ), f"[bwd_v3] delta shape {delta.shape} != {expected_delta_shape}" # V3 expects (dq, dk, dv, softmax_d, *rest) # delta is the softmax_d in this case diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 1f8fd86d1f7..27ae6cae3eb 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -555,6 +555,7 @@ def generate_varlen_kv_packed( x.requires_grad_() return x, cu_seqlens, max_seqlen + # ------------------------------- # Alibi # ------------------------------- @@ -608,7 +609,6 @@ def is_dtype_fp8(dtype) -> bool: return True - _RECOMMENDED_FP8_REPLACEMENTS = { "gfx942": { torch.float8_e4m3fn: torch.float8_e4m3fnuz, @@ -616,6 +616,7 @@ def is_dtype_fp8(dtype) -> bool: }, } + def get_recommended_fp8_dtype(x): dtype = x.dtype if isinstance(x, torch.Tensor) else x if not is_dtype_fp8(dtype): @@ -623,6 +624,7 @@ def get_recommended_fp8_dtype(x): arch = get_arch() return _RECOMMENDED_FP8_REPLACEMENTS.get(arch, {}).get(dtype, dtype) + def is_fp8(x) -> bool: """Return whether tensor(s) use FP8. @@ -1491,9 +1493,13 @@ def is_hip(): def get_arch(): return triton.runtime.driver.active.get_current_target().arch + @functools.cache def get_cu_count(): - return torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count + return torch.cuda.get_device_properties( + torch.cuda.current_device() + ).multi_processor_count + @functools.cache def is_cdna(): @@ -1516,4 +1522,4 @@ def is_rdna(): "gfx1102", "gfx1200", "gfx1201", - ) \ No newline at end of file + ) From c478b8f02a139c78d13cd13dfee8de6140a02177 Mon Sep 17 00:00:00 2001 From: Michael Date: Fri, 10 Oct 2025 17:29:23 -0500 Subject: [PATCH 24/33] clean --- flash_attn/flash_attn_triton_amd/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 27ae6cae3eb..b6eb5ed025a 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -1271,9 +1271,7 @@ def _apply_rotary_kernel( batch, ) - # NOTE: We assume CUDA device indexing compatibility in upstream; adapt for ROCm by using device context. - # For ROCm, torch.cuda.device works if HIP_VISIBLE_DEVICES mapping is set. - with torch.cuda.device(x.device.index): # Works for ROCm as alias + with torch.cuda.device(x.device.index): torch.library.wrap_triton(_rotary_kernel)[grid]( out, x, From f25ff54c20e7fb10ad188298106ea5cbc4c14c6e Mon Sep 17 00:00:00 2001 From: Michael Date: Fri, 10 Oct 2025 20:30:43 -0500 Subject: [PATCH 25/33] start tuning gfx950 --- flash_attn/flash_attn_triton_amd/bwd.py | 58 +++++++++++++++++++++---- 1 file changed, 50 insertions(+), 8 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 2becd57c1c7..18c4eb96b9b 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -241,6 +241,55 @@ def get_bwd_configs(autotune: bool): num_warps=4, ), ] + elif arch == "gfx950": + preprocess_autotune_configs = [ + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=1, num_warps=8 + ), + ] + noncausal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ) + ] + causal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + ] else: preprocess_autotune_configs = [ triton.Config( @@ -256,7 +305,6 @@ def get_bwd_configs(autotune: bool): "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, - "matrix_instr_nonkdim": 16, }, num_stages=1, num_warps=4, @@ -271,7 +319,6 @@ def get_bwd_configs(autotune: bool): "BLOCK_N2": 64, "BLK_SLICE_FACTOR": 2, "waves_per_eu": 1, - "matrix_instr_nonkdim": 16, }, num_stages=1, num_warps=4, @@ -305,7 +352,6 @@ def get_bwd_configs(autotune: bool): NUM_STAGES_OPTIONS = [1, 2] # og: 1 NUM_WARPS_OPTIONS = [4, 8] # og: 4 WAVES_PER_EU_OPTIONS = [1, 2] # og: 1 - MATRIX_INSTR_NONKDIM_OPTIONS = [16, 32] # og: 16 CAUSAL_BLOCK_M1_OPTIONS = [ # og: 32 32, 64, @@ -338,7 +384,6 @@ def get_bwd_configs(autotune: bool): for num_warps in NUM_WARPS_OPTIONS: for num_stages in NUM_STAGES_OPTIONS: for waves in WAVES_PER_EU_OPTIONS: - for matrix_instr_nonkdim in MATRIX_INSTR_NONKDIM_OPTIONS: for m1 in CAUSAL_BLOCK_M1_OPTIONS: for n1 in CAUSAL_BLOCK_N1_M2_OPTIONS: m2 = n1 @@ -358,7 +403,6 @@ def get_bwd_configs(autotune: bool): "BLOCK_N2": n2, "BLK_SLICE_FACTOR": blk_slice, "waves_per_eu": waves, - "matrix_instr_nonkdim": matrix_instr_nonkdim, }, num_stages=num_stages, num_warps=num_warps, @@ -369,7 +413,6 @@ def get_bwd_configs(autotune: bool): for num_warps in NUM_WARPS_OPTIONS: for num_stages in NUM_STAGES_OPTIONS: for waves in WAVES_PER_EU_OPTIONS: - for matrix_instr_nonkdim in MATRIX_INSTR_NONKDIM_OPTIONS: for m1 in NON_CAUSAL_BLOCK_M1_OPTIONS: for n1 in NON_CAUSAL_BLOCK_N1_M2_OPTIONS: m2 = n1 @@ -388,7 +431,6 @@ def get_bwd_configs(autotune: bool): "BLOCK_N2": n2, "BLK_SLICE_FACTOR": blk_slice, "waves_per_eu": waves, - "matrix_instr_nonkdim": matrix_instr_nonkdim, }, num_stages=num_stages, num_warps=num_warps, @@ -407,7 +449,7 @@ def get_bwd_configs(autotune: bool): (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys), -) = get_bwd_configs(AUTOTUNE) +) = get_bwd_configs(False) @triton.jit From d03c5a02796481e8a22d21e9b3a90bc6db2f3a4c Mon Sep 17 00:00:00 2001 From: Michael Date: Sat, 11 Oct 2025 08:42:41 -0500 Subject: [PATCH 26/33] tune non causal path --- flash_attn/flash_attn_triton_amd/bwd.py | 31 +++++++++++++++++++------ 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 18c4eb96b9b..1c2331cb5ca 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -249,6 +249,9 @@ def get_bwd_configs(autotune: bool): triton.Config( {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=1, num_warps=8 ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=4 + ), ] noncausal_autotune_configs = [ triton.Config( @@ -274,7 +277,19 @@ def get_bwd_configs(autotune: bool): }, num_stages=1, num_warps=4, - ) + ), + triton.Config( + { + "BLOCK_M1": 16, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 2, + }, + num_stages=1, + num_warps=4, + ), ] causal_autotune_configs = [ triton.Config( @@ -352,18 +367,21 @@ def get_bwd_configs(autotune: bool): NUM_STAGES_OPTIONS = [1, 2] # og: 1 NUM_WARPS_OPTIONS = [4, 8] # og: 4 WAVES_PER_EU_OPTIONS = [1, 2] # og: 1 + NON_CAUSAL_BLOCK_M1_OPTIONS = [16, 32, 64] # og: 32 + NON_CAUSAL_BLOCK_N1_M2_OPTIONS = [64, 128, 256] # og: 128 + NON_CAUSAL_BLOCK_N2_OPTIONS = [16, 32, 64] # og: 32 CAUSAL_BLOCK_M1_OPTIONS = [ # og: 32 + 16, 32, 64, ] CAUSAL_BLOCK_N1_M2_OPTIONS = [64, 128, 256] # og: 128 - CAUSAL_BLOCK_N2_OPTIONS = [32, 64] # og: 32 - NON_CAUSAL_BLOCK_M1_OPTIONS = [32, 64] # og: 32 - NON_CAUSAL_BLOCK_N1_M2_OPTIONS = [64, 128, 256] # og: 128 - NON_CAUSAL_BLOCK_N2_OPTIONS = [32, 64] # og: 32 + CAUSAL_BLOCK_N2_OPTIONS = [16, 32, 64] # og: 32 BLK_SLICE_FACTOR_OPTIONS = [2] # og: 2 # ==================== sweep configs ================================ + os.environ["TRITON_PRINT_AUTOTUNING"] = "1" + preprocess_autotune_configs = [] for pre_num_warps in PRE_NUM_WARPS_OPTIONS: for pre_num_stages in PRE_NUM_STAGES_OPTIONS: @@ -444,12 +462,11 @@ def get_bwd_configs(autotune: bool): ) -# os.environ["TRITON_PRINT_AUTOTUNING"] = "1" ( (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys), -) = get_bwd_configs(False) +) = get_bwd_configs(True) @triton.jit From 73755aa9c96edfd1add08834e4ff94b32a5d81ca Mon Sep 17 00:00:00 2001 From: Michael Date: Sat, 11 Oct 2025 08:45:56 -0500 Subject: [PATCH 27/33] fix bug --- flash_attn/flash_attn_triton_amd/bwd.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 1c2331cb5ca..05e40080d00 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -3063,8 +3063,8 @@ def bwd_kernel_fused_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_ + offs_d_v[None, :] * stride_vd ) # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k, mask=mask_k, other=0.0) - v = tl.load(V + adj_v, mask=mask_v, other=0.0) + k = tl.load(K + adj_k, mask=mask_k) + v = tl.load(V + adj_v, mask=mask_v) # If MQA / GQA, set the K and V head offsets appropriately. # hqid = hkid for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): @@ -3615,8 +3615,8 @@ def bwd_kernel_fused_noncausal( + offs_d_v[None, :] * stride_vd ) # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k, mask=mask_k, other=0.0) - v = tl.load(V + adj_v, mask=mask_v, other=0.0) + k = tl.load(K + adj_k, mask=mask_k) + v = tl.load(V + adj_v, mask=mask_v) # If MQA / GQA, set the K and V head offsets appropriately. for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): # offset input and output tensor by batch and Q/K heads From 6e36a8224ce08f2f91a4fa357882028c802315b3 Mon Sep 17 00:00:00 2001 From: Michael Date: Sat, 11 Oct 2025 10:39:19 -0500 Subject: [PATCH 28/33] save --- flash_attn/flash_attn_triton_amd/bwd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 05e40080d00..2e4f9ab7dee 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -371,12 +371,12 @@ def get_bwd_configs(autotune: bool): NON_CAUSAL_BLOCK_N1_M2_OPTIONS = [64, 128, 256] # og: 128 NON_CAUSAL_BLOCK_N2_OPTIONS = [16, 32, 64] # og: 32 CAUSAL_BLOCK_M1_OPTIONS = [ # og: 32 - 16, 32, 64, + 128 ] CAUSAL_BLOCK_N1_M2_OPTIONS = [64, 128, 256] # og: 128 - CAUSAL_BLOCK_N2_OPTIONS = [16, 32, 64] # og: 32 + CAUSAL_BLOCK_N2_OPTIONS = [32, 64, 128] # og: 32 BLK_SLICE_FACTOR_OPTIONS = [2] # og: 2 # ==================== sweep configs ================================ From 3b144619d4c39c975158f9a714bf82f70a2b3656 Mon Sep 17 00:00:00 2001 From: Michael Date: Sat, 11 Oct 2025 10:44:31 -0500 Subject: [PATCH 29/33] Skip configs where BLOCK_M2 % BLOCK_N2 != 0 --- flash_attn/flash_attn_triton_amd/bwd.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 2e4f9ab7dee..688647b7a97 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -410,6 +410,10 @@ def get_bwd_configs(autotune: bool): assert ( n1 == m2 ), f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" + + # Skip configs where BLOCK_M2 % BLOCK_N2 != 0 + if m2 % n2 != 0: + continue for blk_slice in BLK_SLICE_FACTOR_OPTIONS: causal_autotune_configs.append( @@ -439,6 +443,11 @@ def get_bwd_configs(autotune: bool): assert ( n1 == m2 ), f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" + + # Skip configs where BLOCK_M2 % BLOCK_N2 != 0 + if m2 % n2 != 0: + continue + for blk_slice in BLK_SLICE_FACTOR_OPTIONS: noncausal_autotune_configs.append( triton.Config( From 75aa2911c5a71f81a84c4c43e5c4503cf1372744 Mon Sep 17 00:00:00 2001 From: Michael Date: Sat, 11 Oct 2025 10:53:43 -0500 Subject: [PATCH 30/33] skip more --- flash_attn/flash_attn_triton_amd/bwd.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 688647b7a97..7bac38b4cd8 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -414,6 +414,10 @@ def get_bwd_configs(autotune: bool): # Skip configs where BLOCK_M2 % BLOCK_N2 != 0 if m2 % n2 != 0: continue + + # Skip configs where BLOCK_N1 % BLOCK_M1 != 0 + if n1 % m1 != 0: + continue for blk_slice in BLK_SLICE_FACTOR_OPTIONS: causal_autotune_configs.append( @@ -447,6 +451,10 @@ def get_bwd_configs(autotune: bool): # Skip configs where BLOCK_M2 % BLOCK_N2 != 0 if m2 % n2 != 0: continue + + # Skip configs where BLOCK_N1 % BLOCK_M1 != 0 + if n1 % m1 != 0: + continue for blk_slice in BLK_SLICE_FACTOR_OPTIONS: noncausal_autotune_configs.append( From fb53a001aa14183a49719b0c7a812a3e10f1fa56 Mon Sep 17 00:00:00 2001 From: Michael Date: Sat, 11 Oct 2025 19:17:43 -0500 Subject: [PATCH 31/33] stop tuning --- flash_attn/flash_attn_triton_amd/bwd.py | 45 ++++++++++++++++++------- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 7bac38b4cd8..4bedeb947a5 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -266,6 +266,18 @@ def get_bwd_configs(autotune: bool): num_stages=1, num_warps=4, ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 128, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), triton.Config( { "BLOCK_M1": 64, @@ -304,6 +316,18 @@ def get_bwd_configs(autotune: bool): num_stages=1, num_warps=4, ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), ] else: preprocess_autotune_configs = [ @@ -367,21 +391,18 @@ def get_bwd_configs(autotune: bool): NUM_STAGES_OPTIONS = [1, 2] # og: 1 NUM_WARPS_OPTIONS = [4, 8] # og: 4 WAVES_PER_EU_OPTIONS = [1, 2] # og: 1 - NON_CAUSAL_BLOCK_M1_OPTIONS = [16, 32, 64] # og: 32 - NON_CAUSAL_BLOCK_N1_M2_OPTIONS = [64, 128, 256] # og: 128 - NON_CAUSAL_BLOCK_N2_OPTIONS = [16, 32, 64] # og: 32 + NON_CAUSAL_BLOCK_M1_OPTIONS = [16, 32, 64, 128] # og: 32 + NON_CAUSAL_BLOCK_N1_M2_OPTIONS = [32, 64, 128, 256] # og: 128 + NON_CAUSAL_BLOCK_N2_OPTIONS = [16, 32, 64, 128] # og: 32 CAUSAL_BLOCK_M1_OPTIONS = [ # og: 32 32, - 64, - 128 + 64 ] - CAUSAL_BLOCK_N1_M2_OPTIONS = [64, 128, 256] # og: 128 - CAUSAL_BLOCK_N2_OPTIONS = [32, 64, 128] # og: 32 + CAUSAL_BLOCK_N1_M2_OPTIONS = [32, 64, 128] # og: 128 + CAUSAL_BLOCK_N2_OPTIONS = [32, 64] # og: 32 BLK_SLICE_FACTOR_OPTIONS = [2] # og: 2 - # ==================== sweep configs ================================ - os.environ["TRITON_PRINT_AUTOTUNING"] = "1" - + # ==================== sweep configs ================================ preprocess_autotune_configs = [] for pre_num_warps in PRE_NUM_WARPS_OPTIONS: for pre_num_stages in PRE_NUM_STAGES_OPTIONS: @@ -478,12 +499,12 @@ def get_bwd_configs(autotune: bool): (noncausal_autotune_configs, noncausal_autotune_keys), ) - +# os.environ["TRITON_PRINT_AUTOTUNING"] = "1" ( (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys), -) = get_bwd_configs(True) +) = get_bwd_configs(AUTOTUNE) @triton.jit From 0c17f2f7a16957384a729e07e341d505992a02a3 Mon Sep 17 00:00:00 2001 From: Michael Date: Sun, 12 Oct 2025 09:49:52 -0500 Subject: [PATCH 32/33] fix varlen bug --- flash_attn/flash_attn_triton_amd/fwd_prefill.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index b976b0a1cac..a997abcc841 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -1914,9 +1914,6 @@ def attention_forward_prefill_triton_impl( assert ( sd_mask.shape[3] >= max_seqlens_k ), f"sd_mask.shape[3]={sd_mask.shape[3]} must be >= max_seqlens_k={max_seqlens_k}" - assert ( - sd_mask.dtype == torch.float32 - ), f"sd_mask must be float32, got {sd_mask.dtype}" assert sd_mask.device == q.device, f"sd_mask must be on same device as q" if DROPOUT_USE_PYTORCH: From 1d670651c0a35e3e23f91743a7d301a86ee41109 Mon Sep 17 00:00:00 2001 From: Michael Date: Sun, 12 Oct 2025 19:55:36 -0500 Subject: [PATCH 33/33] fix dropout & causal/swa segfault --- flash_attn/flash_attn_triton_amd/bwd.py | 42 +-- .../flash_attn_triton_amd/fwd_prefill.py | 254 +++++++++--------- .../flash_attn_triton_amd/interface_v2.py | 62 +++-- flash_attn/flash_attn_triton_amd/utils.py | 16 +- 4 files changed, 172 insertions(+), 202 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 4bedeb947a5..d2ed7aa113a 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -7,22 +7,13 @@ from .utils import ( DEBUG, AUTOTUNE, - DROPOUT_USE_PYTORCH, - DROPOUT_DUMP, compute_fp8_scaling_factors, - create_dropout_mask, - create_dropout_mask_varlen, get_cu_count, is_cdna, is_fp8, get_arch, ) -# NOTE: triton fails to import tl.constexprs so create them here for the file -tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) -tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) - - def get_bwd_configs(autotune: bool): # keys preprocess_autotune_keys = [ @@ -2658,15 +2649,8 @@ def _bwd_dkdv_inner( + offs_m[None, :] * stride_dropoutm + offs_n[:, None] * stride_dropoutn ) - if tl_DROPOUT_USE_PYTORCH: - dropout_offs = ( - offs_m[None, :] * stride_dropoutm - + offs_n[:, None] * stride_dropoutn - ) - dropout_mask = tl.load(curr_dropout_offset + dropout_offs, mask=mask_nm) - else: - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p dropout_scale = 1.0 / (1 - dropout_p) # Load m before computing qk to reduce pipeline stall. m = tl.load(M + offs_m * stride_lse_m, mask=mask_m, other=0.0) @@ -2849,15 +2833,8 @@ def _bwd_dq_inner( + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn ) - if tl_DROPOUT_USE_PYTORCH: - dropout_offs = ( - offs_m[:, None] * stride_dropoutm - + offs_n[None, :] * stride_dropoutn - ) - dropout_mask = tl.load(curr_dropout_offset + dropout_offs, mask=mask_mn) - else: - rand_vals = tl.rand(philox_seed, philox_offs) - dropout_mask = rand_vals > dropout_p + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p dropout_scale = 1 / (1 - dropout_p) if IS_FP8: @@ -4242,17 +4219,6 @@ def attention_backward_triton_impl( dtype=torch.float32, ) - if DROPOUT_USE_PYTORCH: - if not IS_VARLEN: - dropout_mask = create_dropout_mask( - dropout_p, - (batch, nheads_q, max_seqlen_q, max_seqlen_k), - seed=philox_seed, - ) - else: - dropout_mask = create_dropout_mask_varlen( - dropout_p, batch, nheads_q, cu_seqlens_q, cu_seqlens_k, philox_seed - ) stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = ( dropout_mask.stride() ) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index a997abcc841..81a1de19f20 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -7,8 +7,6 @@ from .utils import ( DEBUG, AUTOTUNE, - DROPOUT_USE_PYTORCH, - DROPOUT_DUMP, compute_alibi_block, compute_fp8_scaling_factors, get_arch, @@ -16,14 +14,10 @@ is_cdna, is_fp8, is_rdna, - create_dropout_mask, apply_rotary, get_recommended_fp8_dtype, ) -# NOTE: triton fails to import tl.constexprs so create them here for the file -tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) -tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) def get_fwd_configs(autotune: bool): @@ -178,14 +172,18 @@ def _attn_fwd_no_mask( stride_vk, stride_bn, stride_sn, + stride_sm, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, - philox_base_ptrs, - sd_mask_base_ptrs, - dropout_mask_base_ptrs, + philox_offset_base, + sd_mask, + stride_sz, + stride_sh, + off_z, + off_h_q, offs_m, offs_n, offs_d_qk, @@ -284,29 +282,34 @@ def _attn_fwd_no_mask( # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: - dropout_mask_ptrs = dropout_mask_base_ptrs + start_n * stride_sn - sd_mask_ptrs = sd_mask_base_ptrs + start_n * stride_sn - philox_ptrs = philox_base_ptrs + start_n * stride_sn - if tl_DROPOUT_USE_PYTORCH: - dropout_mask = tl.load(dropout_mask_ptrs, mask=qk_mask) - else: - rng_output = tl.rand( - philox_seed, philox_ptrs - ) # TODO: use tl.randint for better performance - dropout_mask = rng_output > dropout_p - if tl_DROPOUT_DUMP: - tl.store(dropout_mask_ptrs, dropout_mask, mask=qk_mask) - - # return scores with negative values for dropped vals - sd_mask = tl.where(dropout_mask, p, -p) - tl.store(sd_mask_ptrs, sd_mask, mask=qk_mask) + # Compute pointers for this block + philox_base = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh + philox_ptrs = philox_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn + + # compute dropout mask + rng_output = tl.rand(philox_seed, philox_ptrs) + dropout_mask = rng_output > dropout_p + + # return scores with negative values for dropped vals (only if RETURN_SCORES is True) + if RETURN_SCORES: + sd_mask_value = tl.where(dropout_mask, p, -p) + sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh + sd_mask_ptrs = sd_mask_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn + + # Compute mask for sd_mask storage + sd_store_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) + tl.store(sd_mask_ptrs, sd_mask_value, mask=sd_store_mask) # apply dropout mask in place p = tl.where(dropout_mask, p, 0.0) elif RETURN_SCORES: # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - sd_mask_ptrs = sd_mask_base_ptrs + start_n * stride_sn - tl.store(sd_mask_ptrs, p, mask=qk_mask) + sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh + sd_mask_ptrs = sd_mask_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn + + # Compute mask for sd_mask storage + sd_store_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) + tl.store(sd_mask_ptrs, p, mask=sd_store_mask) # -- update output accumulator -- # alpha is an adjustment factor for acc and li as we loop and find new maxes @@ -357,14 +360,18 @@ def _attn_fwd_mask( stride_vk, stride_bn, stride_sn, + stride_sm, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, - philox_base_ptrs, - sd_mask_base_ptrs, - dropout_mask_base_ptrs, + philox_offset_base, + sd_mask, + stride_sz, + stride_sh, + off_z, + off_h_q, offs_m, offs_n, offs_d_qk, @@ -586,29 +593,74 @@ def _attn_fwd_mask( # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: - dropout_mask_ptrs = dropout_mask_base_ptrs + start_n * stride_sn - sd_mask_ptrs = sd_mask_base_ptrs + start_n * stride_sn - philox_ptrs = philox_base_ptrs + start_n * stride_sn - if tl_DROPOUT_USE_PYTORCH: - dropout_mask = tl.load(dropout_mask_ptrs, mask=qk_mask) - else: - rng_output = tl.rand( - philox_seed, philox_ptrs - ) # TODO: use tl.randint for better performance - dropout_mask = rng_output > dropout_p - if tl_DROPOUT_DUMP: - tl.store(dropout_mask_ptrs, dropout_mask, mask=qk_mask) - - # return scores with negative values for dropped vals - sd_mask = tl.where(dropout_mask, p, -p) - tl.store(sd_mask_ptrs, sd_mask, mask=qk_mask) + # Compute pointers for this block + philox_base = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh + philox_ptrs = philox_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn + + # compute dropout mask + rng_output = tl.rand(philox_seed, philox_ptrs) + dropout_mask = rng_output > dropout_p + + # return scores with negative values for dropped vals (only if RETURN_SCORES is True) + if RETURN_SCORES: + sd_mask_value = tl.where(dropout_mask, p, -p) + sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh + sd_mask_ptrs = sd_mask_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn + + # Compute mask for sd_mask storage - include bounds check + sd_store_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) + + # Add causal mask if applicable to prevent writing to invalid positions + if IS_CAUSAL: + seqlen_delta_qk = seqlen_k - seqlen_q + causal_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk) + sd_store_mask = sd_store_mask & causal_constraint + + # Add sliding window mask if applicable + if USE_SLIDING_WINDOW: + seqlen_delta_qk = seqlen_k - seqlen_q + if WINDOW_SIZE_LEFT < 0: + # Only right window constraint + window_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT) + else: + # Both left and right window constraints + left_bound = offs_m[:, None] + seqlen_delta_qk - WINDOW_SIZE_LEFT + right_bound = offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT + window_constraint = (kv_offs_n[None, :] >= left_bound) & (kv_offs_n[None, :] <= right_bound) + sd_store_mask = sd_store_mask & window_constraint + + tl.store(sd_mask_ptrs, sd_mask_value, mask=sd_store_mask) # apply dropout mask in place p = tl.where(dropout_mask, p, 0.0) elif RETURN_SCORES: # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - sd_mask_ptrs = sd_mask_base_ptrs + start_n * stride_sn - tl.store(sd_mask_ptrs, p, mask=qk_mask) + sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh + sd_mask_ptrs = sd_mask_base + offs_m[:, None] * stride_sm + kv_offs_n[None, :] * stride_sn + + # Compute mask for sd_mask storage - include bounds check + sd_store_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) + + # Add causal mask if applicable + if IS_CAUSAL: + seqlen_delta_qk = seqlen_k - seqlen_q + causal_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk) + sd_store_mask = sd_store_mask & causal_constraint + + # Add sliding window mask if applicable + if USE_SLIDING_WINDOW: + seqlen_delta_qk = seqlen_k - seqlen_q + if WINDOW_SIZE_LEFT < 0: + # Only right window constraint + window_constraint = kv_offs_n[None, :] <= (offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT) + else: + # Both left and right window constraints + left_bound = offs_m[:, None] + seqlen_delta_qk - WINDOW_SIZE_LEFT + right_bound = offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT + window_constraint = (kv_offs_n[None, :] >= left_bound) & (kv_offs_n[None, :] <= right_bound) + sd_store_mask = sd_store_mask & window_constraint + + tl.store(sd_mask_ptrs, p, mask=sd_store_mask) # -- update output accumulator -- # alpha is an adjustment factor for acc and li as we loop and find new maxes @@ -951,6 +1003,8 @@ def attn_fwd( stride_v_descale_z, LSE, Out, + SD_MASK, + ALIBI_SLOPES, stride_qz, stride_qh, stride_qm, @@ -987,9 +1041,6 @@ def attn_fwd( dropout_p, philox_seed, philox_offset_base, - sd_mask, - dropout_mask, - alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, @@ -1010,7 +1061,6 @@ def attn_fwd( USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, - NEEDS_SDMASK: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, IS_FP8: tl.constexpr, @@ -1190,41 +1240,10 @@ def attn_fwd( if USE_ALIBI: a_offset = off_z * stride_az + off_h_q * stride_ah - alibi_slope = tl.load(alibi_slopes + a_offset) + alibi_slope = tl.load(ALIBI_SLOPES + a_offset) else: alibi_slope = None - if NEEDS_SDMASK: - sd_mask_offset = ( - sd_mask + off_z * stride_sz + off_h_q * stride_sh - ) # + cu_seqlens_q_start * stride_sm - sd_mask_ptrs = ( - sd_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn - ) - else: - sd_mask_ptrs = None - - if ENABLE_DROPOUT: - dropout_mask_offset = ( - dropout_mask + off_z * stride_sz + off_h_q * stride_sh - ) # + cu_seqlens_q_start * stride_sm - dropout_mask_ptrs = ( - dropout_mask_offset - + offs_m[:, None] * stride_sm - + offs_n[None, :] * stride_sn - ) - batch_philox_offset = ( - philox_offset_base + off_z * stride_sz + off_h_q * stride_sh - ) # + cu_seqlens_q_start * stride_sm - philox_ptrs = ( - batch_philox_offset - + offs_m[:, None] * stride_sm - + offs_n[None, :] * stride_sn - ) - else: - dropout_mask_ptrs = None - philox_ptrs = 0 - # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=ACCUMULATOR_TYPE) l_i = tl.full([BLOCK_M], 1.0, dtype=ACCUMULATOR_TYPE) @@ -1254,14 +1273,18 @@ def attn_fwd( stride_vk, stride_bn, stride_sn, + stride_sm, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, - philox_ptrs, - sd_mask_ptrs, - dropout_mask_ptrs, + philox_offset_base, + SD_MASK, + stride_sz, + stride_sh, + off_z, + off_h_q, offs_m, offs_n, offs_d_qk, @@ -1316,14 +1339,18 @@ def attn_fwd( stride_vk, stride_bn, stride_sn, + stride_sm, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, - philox_ptrs, - sd_mask_ptrs, - dropout_mask_ptrs, + philox_offset_base, + SD_MASK, + stride_sz, + stride_sh, + off_z, + off_h_q, offs_m, offs_n, offs_d_qk, @@ -1378,14 +1405,18 @@ def attn_fwd( stride_vk, stride_bn, stride_sn, + stride_sm, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, - philox_ptrs, - sd_mask_ptrs, - dropout_mask_ptrs, + philox_offset_base, + SD_MASK, + stride_sz, + stride_sh, + off_z, + off_h_q, offs_m, offs_n, offs_d_qk, @@ -1566,7 +1597,7 @@ def attention_forward_prefill_triton_impl( philox_seed: Optional[int], philox_offset: Optional[int], # misc - return_softmax: bool, + return_scores: bool, use_exp2: bool, # fp8 q_descale: Optional[torch.Tensor], @@ -1892,15 +1923,12 @@ def attention_forward_prefill_triton_impl( padded_d_model_qk = max(padded_d_model_qk, 16) padded_d_model_v = max(padded_d_model_v, 16) - # sd_mask is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out - # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according - # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing - # only. This return holds no useful output aside from debugging. - NEEDS_SDMASK = (dropout_p > 0.0) or return_softmax - if NEEDS_SDMASK: + # sd_mask assertions and strides + if sd_mask is not None: + assert dropout_p > 0.0 or return_scores, "sd_mask provided but not used" assert ( sd_mask is not None - ), "sd_mask must be provided when return_softmax=True or dropout_p > 0" + ), "sd_mask must be provided when return_scores=True or dropout_p > 0" # Assert sd_mask tensor is large enough assert ( sd_mask.shape[0] >= batch @@ -1916,18 +1944,6 @@ def attention_forward_prefill_triton_impl( ), f"sd_mask.shape[3]={sd_mask.shape[3]} must be >= max_seqlens_k={max_seqlens_k}" assert sd_mask.device == q.device, f"sd_mask must be on same device as q" - if DROPOUT_USE_PYTORCH: - dropout_mask = create_dropout_mask( - dropout_p, - (batch, nheads_q, max_seqlens_q, max_seqlens_k), - seed=philox_seed, - ) - else: - dropout_mask = torch.zeros( - (batch, nheads_q, max_seqlens_q, max_seqlens_k), - device=q.device, - dtype=torch.float32, - ) stride_sz, stride_sh, stride_sm, stride_sn = ( sd_mask.stride(0), sd_mask.stride(1), @@ -1935,8 +1951,6 @@ def attention_forward_prefill_triton_impl( sd_mask.stride(3), ) else: - sd_mask = None - dropout_mask = None stride_sz, stride_sh, stride_sm, stride_sn = (0, 0, 0, 0) if bias is not None: @@ -1964,6 +1978,8 @@ def attention_forward_prefill_triton_impl( stride_v_descale_z, softmax_lse, o, + sd_mask, + alibi_slopes, stride_qb, stride_qh, stride_qm, @@ -2000,9 +2016,6 @@ def attention_forward_prefill_triton_impl( dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, - sd_mask=sd_mask, - dropout_mask=dropout_mask, - alibi_slopes=alibi_slopes, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL_QK=head_size_qk, @@ -2021,10 +2034,9 @@ def attention_forward_prefill_triton_impl( USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p > 0.0, USE_EXP2=use_exp2, - RETURN_SCORES=return_softmax, - NEEDS_SDMASK=NEEDS_SDMASK, + RETURN_SCORES=return_scores, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_P_DESCALE=False, USE_SEQUSED=(seqused_q is not None or seqused_k is not None), - ) # Add flag for seqused + ) diff --git a/flash_attn/flash_attn_triton_amd/interface_v2.py b/flash_attn/flash_attn_triton_amd/interface_v2.py index dfa3c4fdae6..d303ba63e7a 100644 --- a/flash_attn/flash_attn_triton_amd/interface_v2.py +++ b/flash_attn/flash_attn_triton_amd/interface_v2.py @@ -46,10 +46,10 @@ def fwd( if DEBUG: print() print("flash_attn_triton_amd.py::fwd inputs") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("out:", out, out.shape if out is not None else None) + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("out:", out.shape if out is not None else None) print("alibi_slopes:", alibi_slopes) print("dropout_p:", dropout_p) print("softmax_scale:", softmax_scale) @@ -58,7 +58,11 @@ def fwd( print("window_size_right:", window_size_right) print("softcap:", softcap) print("return_softmax:", return_softmax) - out = torch.zeros_like(q) if out is None else out.zero_() + + if out is None: + out = torch.zeros_like(q) + else: + out.zero_() # Layout / shapes layout = "bshd" @@ -92,7 +96,7 @@ def fwd( device=q.device, dtype=torch.float32, ) - if return_softmax: + if dropout_p > 0.0 or return_softmax: sd_mask = torch.zeros( ( batch, @@ -111,7 +115,7 @@ def fwd( device=q.device, dtype=torch.float32, ) - if return_softmax: + if dropout_p > 0.0 or return_softmax: sd_mask = torch.zeros( (batch, nheads_q, max_seqlen_q, max_seqlen_k), device=q.device, @@ -157,9 +161,9 @@ def fwd( if DEBUG: print("flash_attn_triton_amd.py::fwd outputs") - print("o:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None) + print("o:", out.shape if out is not None else None) + print("softmax_lse:", softmax_lse.shape if softmax_lse is not None else None) + print("sd_mask:", sd_mask.shape if sd_mask is not None else None) print("rng_state:", rng_state) # --- Assertions (shape + dtype contracts) --- @@ -232,14 +236,14 @@ def bwd( print() print("flash_attn_triton_amd.py::bwd inputs") print("dout:", dout, dout.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("out:", out, out.shape) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape if dq is not None else None) - print("dk:", dk, dk.shape if dk is not None else None) - print("dv:", dv, dv.shape if dv is not None else None) + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) + print("out:", out.shape) + print("softmax_lse:", softmax_lse.shape) + print("dq:", dq.shape if dq is not None else None) + print("dk:", dk.shape if dk is not None else None) + print("dv:", dv.shape if dv is not None else None) print("alibi_slopes:", alibi_slopes) print("dropout_p:", dropout_p) print("out:", out) @@ -385,9 +389,9 @@ def varlen_fwd( if DEBUG: print() print("flash_attn_triton_amd.py::varlen_fwd") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) print("alibi_slopes:", alibi_slopes) @@ -566,15 +570,15 @@ def varlen_bwd( if DEBUG: print() print("varlen_bwd") - print("dout:", dout, dout.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) + print("dout:", dout.shape) + print("q:", q.shape) + print("k:", k.shape) + print("v:", v.shape) print("out:", out) - print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape if dq is not None else None) - print("dk:", dk, dk.shape if dk is not None else None) - print("dv:", dv, dv.shape if dv is not None else None) + print("softmax_lse:", softmax_lse.shape) + print("dq:", dq.shape if dq is not None else None) + print("dk:", dk.shape if dk is not None else None) + print("dv:", dv.shape if dv is not None else None) print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) print("alibi_slopes:", alibi_slopes) diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index b6eb5ed025a..d0a20eb6aa4 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -17,23 +17,13 @@ "true", "yes", ) -if AUTOTUNE: - os.environ["TRITON_PRINT_AUTOTUNING"] = "1" DEBUG = os.environ.get("FLASH_ATTENTION_TRITON_AMD_DEBUG", "0").lower() in ( "1", "true", "yes", ) -PERF = os.environ.get("FLASH_ATTENTION_TRITON_AMD_PERF", "0").lower() in ( - "1", - "true", - "yes", -) -USE_SINGLE_BWD_KERNEL = os.environ.get("USE_SINGLE_BWD_KERNEL", "0").lower() in ( - "1", - "true", - "yes", -) +if AUTOTUNE or DEBUG: + os.environ["TRITON_PRINT_AUTOTUNING"] = "1" USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" USE_TRITON_INTERPRET = os.environ.get("TRITON_INTERPRET", "0").lower() in ( "1", @@ -51,8 +41,6 @@ if USE_TRITON_ROCM: # TODO remove this random.seed(42) BWD_MODE: Literal["fused", "fused_atomic", "split"] = "fused" -DROPOUT_USE_PYTORCH = False -DROPOUT_DUMP = False USE_EXP2 = True PHILOX_SEED = 0x1BF58 PHILOX_OFFSET = 0x1D4B49