From 4f1d8bb50c337e296468ba4e30a93fd3f29882b4 Mon Sep 17 00:00:00 2001 From: Tianwei Yang Date: Tue, 6 Jan 2026 02:35:07 +0000 Subject: [PATCH 1/6] [Navi]add more triton config --- .../flash_attn_triton_amd/fwd_decode.py | 27 ++++++++++++++- .../flash_attn_triton_amd/fwd_prefill.py | 34 ++++++++++--------- 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index 3f2d92c22d6..5c16cf4c552 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -2,7 +2,7 @@ import triton import triton.language as tl from typing import Literal, Optional, Union -from .utils import AUTOTUNE, DEBUG, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna +from .utils import AUTOTUNE, DEBUG, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna, is_rdna def get_cdna_autotune_configs(): return [ @@ -23,6 +23,26 @@ def get_cdna_autotune_configs(): num_warps=4), ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] +def get_rdna_autotune_configs(): + return [ + # Most aggressive - 128x128 (best for large sequences) + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + # Large blocks + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + # Medium blocks + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), + # Fall-back config. 
+        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
+                      num_warps=2),
+    ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']
+
 def get_autotune_configs():
     if AUTOTUNE:
         if is_cdna():
@@ -30,6 +50,11 @@ def get_autotune_configs():
             fwd_auto_tune_configs, fwd_autotune_keys= autotune_configs, autotune_keys
             reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys
             return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys)
+        elif is_rdna():
+            autotune_configs, autotune_keys = get_rdna_autotune_configs()
+            fwd_auto_tune_configs, fwd_autotune_keys = autotune_configs, autotune_keys
+            reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys
+            return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys)
         else:
             raise ValueError("Unknown Device Type")
     else:
diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
index 6f69cd02813..7e814be4118 100644
--- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py
+++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -186,24 +186,25 @@ def get_cdna_autotune_configs():
 
 def get_rdna_autotune_configs():
     return [
-        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
-                      num_warps=2),
-        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
-                      num_warps=2),
-        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
-                      num_warps=2),
-        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
-                      num_warps=2),
-        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
-                      num_warps=2),
-        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
-                      num_warps=2),
-        # Fall-back config.
- triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), + # === Configs for head_dim=64 (optimal) === + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + # === Configs for head_dim=128 (Wan2.2) - smaller blocks to reduce register pressure === + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + # === General fallback configs === + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] + def get_autotune_configs(): if AUTOTUNE: if is_rdna(): @@ -214,8 +215,9 @@ def get_autotune_configs(): raise ValueError("Unknown Device Type") else: return [ + # Use BLOCK_N=32 to avoid register spilling on gfx1100 triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + {"BLOCK_M": 64, "BLOCK_N": 32, "waves_per_eu": 1, "PRE_LOAD_V": True}, num_stages=1, num_warps=4, ), From 929a4bbc89cb52856b0400567849344e5a64793e Mon Sep 17 00:00:00 2001 From: Tianwei Yang Date: Tue, 6 Jan 2026 02:35:44 +0000 Subject: [PATCH 2/6] [Navi]enable exp2 by default --- flash_attn/flash_attn_triton_amd/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 5d3bf02e1f8..167b99e0d81 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -48,7 +48,7 @@ class MetaData(): philox_seed: Optional[int] = None philox_offset : Optional[int]= None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing. # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW. 
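     # Concretely, the identity used is e^y = 2^(y * log2(e)): sm_scale is folded
     # together with log2(e) once up front, and the inner loop then evaluates the
     # cheaper exp2 rather than exp.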
- use_exp2: bool = False + use_exp2: bool = True rotary_sin: Optional[torch.Tensor] = None rotary_cos: Optional[torch.Tensor] = None rotary_interleaved: bool = False From dc8b05da38ebed93060beff071676b0d326d3f3f Mon Sep 17 00:00:00 2001 From: Tianwei Yang Date: Tue, 6 Jan 2026 02:36:39 +0000 Subject: [PATCH 3/6] [Navi]Add support for arch gfx1100 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f0b476255ba..9b1bd10088a 100644 --- a/setup.py +++ b/setup.py @@ -197,7 +197,7 @@ def rename_cpp_to_cu(cpp_files): def validate_and_update_archs(archs): # List of allowed architectures - allowed_archs = ["native", "gfx90a", "gfx950", "gfx942"] + allowed_archs = ["native", "gfx90a", "gfx950", "gfx942", "gfx1100"] # Validate if each element in archs is in allowed_archs assert all( From 64c5924c5ac7a73fa5747e9759d410b325b95a2d Mon Sep 17 00:00:00 2001 From: Tianwei Yang Date: Tue, 6 Jan 2026 02:38:03 +0000 Subject: [PATCH 4/6] [ROCM]warp fa to support L2 cache aware to improve performance --- .../flash_attn_triton_amd/interface_fa.py | 233 ++++++++++++++- .../flash_attn_triton_amd/l2_cache_aware.py | 276 ++++++++++++++++++ 2 files changed, 507 insertions(+), 2 deletions(-) create mode 100644 flash_attn/flash_attn_triton_amd/l2_cache_aware.py diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index 06ab7d24d56..c223ee93b6c 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -9,11 +9,17 @@ from .fwd_ref import attention_forward_pytorch_ref_impl from .bwd_ref import attention_backward_pytorch_ref_impl from .utils import DEBUG, USE_REF, MetaData, get_shapes_from_layout, is_fp8 +from .l2_cache_aware import is_head_grouping_beneficial, print_head_grouping_info from einops import rearrange, repeat from flash_attn.layers.rotary import apply_rotary_emb from typing import Literal, Optional, Union -def fwd(q: torch.Tensor, +# Environment variable to enable verbose head grouping output +L2_HEAD_GROUPING_DEBUG = os.environ.get('FLASH_ATTN_HEAD_GROUPING_DEBUG', '0') == '1' + + + +def _fwd_single_group(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: Optional[torch.Tensor], @@ -31,6 +37,7 @@ def fwd(q: torch.Tensor, descale_v: Optional[torch.Tensor] = None, descale_o: Optional[torch.Tensor] = None ): + """Original fwd implementation for a single head group.""" if DEBUG: print() @@ -145,6 +152,112 @@ def fwd(q: torch.Tensor, return out, softmax_lse, sd_mask, rng_state + +def fwd(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None + ): + """ + Flash attention forward with L2 cache-aware head grouping. + + For consumer AMD GPUs (e.g., gfx1100 with 96MB L2), when K,V tensors + exceed L2 cache capacity, processing heads in groups that fit in L2 + can provide up to 2x speedup. 
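+
+    Grouping is exact rather than an approximation: attention is computed
+    independently per head, so running the heads in groups and concatenating
+    the per-group outputs and LSE matches the single-pass result.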
+
+    Layout: bshd (batch, seqlen, heads, head_dim)
+    """
+    # Get shapes for head grouping decision
+    # Layout is bshd: [batch, seqlen, heads, head_dim]
+    batch, seqlen_q, nheads_q, head_dim = q.shape
+    seqlen_k = k.shape[1]
+    nheads_k = k.shape[2]
+
+    # Check if head grouping is beneficial
+    should_group, group_size = is_head_grouping_beneficial(
+        nheads_q, seqlen_k, head_dim, q.dtype, q.device.index or 0
+    )
+
+    if L2_HEAD_GROUPING_DEBUG:
+        print_head_grouping_info(nheads_q, seqlen_k, head_dim, q.dtype, q.device.index or 0)
+
+    # The grouped path slices K and V with query-head indices, so it assumes
+    # MHA (nheads_k == nheads_q); GQA/MQA also falls back to the single group.
+    if not should_group or group_size >= nheads_q or nheads_k != nheads_q:
+        # No grouping needed - use original implementation
+        return _fwd_single_group(
+            q, k, v, out, alibi_slopes, dropout_p, softmax_scale, causal,
+            window_size_left, window_size_right, softcap, return_softmax,
+            gen_, descale_q, descale_k, descale_v, descale_o
+        )
+
+    # Process heads in groups for L2 cache efficiency
+    if L2_HEAD_GROUPING_DEBUG:
+        print(f"[L2 Head Grouping] Processing {nheads_q} heads in groups of {group_size}")
+
+    # Prepare output tensor
+    if is_fp8(q):
+        assert out is not None, "fp8 output tensor should be passed in."
+    else:
+        out = torch.zeros_like(q) if out is None else out.zero_()
+
+    # Collect outputs for each group
+    softmax_lse_list = []
+    rng_state = None
+
+    n_groups = (nheads_q + group_size - 1) // group_size
+
+    for g in range(n_groups):
+        start_h = g * group_size
+        end_h = min((g + 1) * group_size, nheads_q)
+
+        # Slice heads: bshd layout -> select heads on dim 2
+        q_group = q[:, :, start_h:end_h, :].contiguous()
+        k_group = k[:, :, start_h:end_h, :].contiguous()
+        v_group = v[:, :, start_h:end_h, :].contiguous()
+        out_group = out[:, :, start_h:end_h, :].contiguous()
+
+        # Handle alibi slopes if present
+        alibi_group = None
+        if alibi_slopes is not None:
+            alibi_group = alibi_slopes[start_h:end_h] if alibi_slopes.dim() == 1 else alibi_slopes[:, start_h:end_h]
+
+        # Handle descale tensors for fp8
+        descale_q_g = descale_q[:, start_h:end_h] if descale_q is not None else None
+        descale_k_g = descale_k[:, start_h:end_h] if descale_k is not None else None
+        descale_v_g = descale_v[:, start_h:end_h] if descale_v is not None else None
+        descale_o_g = descale_o[:, start_h:end_h] if descale_o is not None else None
+
+        # Call the original implementation for this group
+        out_g, softmax_lse_g, sd_mask_g, rng_state = _fwd_single_group(
+            q_group, k_group, v_group, out_group, alibi_group,
+            dropout_p, softmax_scale, causal,
+            window_size_left, window_size_right, softcap, return_softmax,
+            gen_, descale_q_g, descale_k_g, descale_v_g, descale_o_g
+        )
+
+        # Copy output back to the main tensor
+        out[:, :, start_h:end_h, :] = out_g
+        softmax_lse_list.append(softmax_lse_g)
+
+    # Concatenate softmax_lse across heads
+    softmax_lse = torch.cat(softmax_lse_list, dim=1)  # Assuming lse is [batch, heads, ...]
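+    # NOTE: the per-group sd_mask tensors are not stitched back together, so
+    # the grouped path always returns None for sd_mask; use the single-group
+    # path when the dropout/softmax mask (return_softmax) is required.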
+ + return out, softmax_lse, None, rng_state + + + BWD_MODE = os.environ.get('BWD_MODE', 'split').lower() def bwd( dout: torch.Tensor, @@ -349,7 +462,7 @@ def bwd( print("descale_dq:", descale_dq, descale_dq.shape if descale_dq is not None else None) return dq, dk, dv, delta -def varlen_fwd( +def _varlen_fwd_single_group( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -376,6 +489,7 @@ def varlen_fwd( descale_v: Optional[torch.Tensor] = None, descale_o: Optional[torch.Tensor] = None ): + """Original varlen_fwd implementation for a single head group.""" if DEBUG: print() @@ -490,6 +604,121 @@ def varlen_fwd( return out, softmax_lse, sd_mask, rng_state + +def varlen_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], + leftpad_k: Optional[torch.Tensor], + block_table_: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + softmax_scale: float, + zero_tensors: bool , + causal: bool , + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None + ): + """ + Variable-length flash attention forward with L2 cache-aware head grouping. + + For consumer AMD GPUs (e.g., gfx1100 with 96MB L2), when K,V tensors + exceed L2 cache capacity, processing heads in groups that fit in L2 + can provide up to 2x speedup. + + Layout: thd (total_seqlen, heads, head_dim) + """ + # Get shapes for head grouping decision + # Layout is thd: [total_seqlen, heads, head_dim] + total_seqlen, nheads_q, head_dim = q.shape + nheads_k = k.shape[1] + + # Check if head grouping is beneficial + should_group, group_size = is_head_grouping_beneficial( + nheads_q, max_seqlen_k, head_dim, q.dtype, q.device.index or 0 + ) + + if L2_HEAD_GROUPING_DEBUG: + print_head_grouping_info(nheads_q, max_seqlen_k, head_dim, q.dtype, q.device.index or 0) + + if not should_group or group_size >= nheads_q: + # No grouping needed - use original implementation + return _varlen_fwd_single_group( + q, k, v, out, cu_seqlens_q, cu_seqlens_k, seqused_k, leftpad_k, + block_table_, alibi_slopes, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, zero_tensors, causal, + window_size_left, window_size_right, softcap, return_softmax, + gen_, descale_q, descale_k, descale_v, descale_o + ) + + # Process heads in groups for L2 cache efficiency + if L2_HEAD_GROUPING_DEBUG: + print(f"[L2 Head Grouping varlen] Processing {nheads_q} heads in groups of {group_size}") + + # Prepare output tensor + if is_fp8(q): + assert out is not None, "fp8 output tensor should be passed in." 
+    else:
+        out = torch.zeros_like(q) if out is None else out.zero_()
+
+    # The grouped path slices K and V with query-head indices, so it assumes
+    # MHA (nheads_k == nheads_q); for GQA/MQA, fall back to the single-group path.
+    if nheads_k != nheads_q:
+        return _varlen_fwd_single_group(
+            q, k, v, out, cu_seqlens_q, cu_seqlens_k, seqused_k, leftpad_k,
+            block_table_, alibi_slopes, max_seqlen_q, max_seqlen_k,
+            dropout_p, softmax_scale, zero_tensors, causal,
+            window_size_left, window_size_right, softcap, return_softmax,
+            gen_, descale_q, descale_k, descale_v, descale_o
+        )
+
+    # Collect outputs for each group
+    softmax_lse_list = []
+    rng_state = None
+
+    n_groups = (nheads_q + group_size - 1) // group_size
+
+    for g in range(n_groups):
+        start_h = g * group_size
+        end_h = min((g + 1) * group_size, nheads_q)
+
+        # Slice heads: thd layout -> select heads on dim 1
+        q_group = q[:, start_h:end_h, :].contiguous()
+        k_group = k[:, start_h:end_h, :].contiguous()
+        v_group = v[:, start_h:end_h, :].contiguous()
+        out_group = out[:, start_h:end_h, :].contiguous()
+
+        # Handle alibi slopes if present
+        alibi_group = None
+        if alibi_slopes is not None:
+            alibi_group = alibi_slopes[start_h:end_h] if alibi_slopes.dim() == 1 else alibi_slopes[:, start_h:end_h]
+
+        # Handle descale tensors for fp8
+        descale_q_g = descale_q[:, start_h:end_h] if descale_q is not None else None
+        descale_k_g = descale_k[:, start_h:end_h] if descale_k is not None else None
+        descale_v_g = descale_v[:, start_h:end_h] if descale_v is not None else None
+        descale_o_g = descale_o[:, start_h:end_h] if descale_o is not None else None
+
+        # Call the original implementation for this group
+        out_g, softmax_lse_g, sd_mask_g, rng_state = _varlen_fwd_single_group(
+            q_group, k_group, v_group, out_group, cu_seqlens_q, cu_seqlens_k,
+            seqused_k, leftpad_k, block_table_, alibi_group,
+            max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale,
+            zero_tensors, causal, window_size_left, window_size_right,
+            softcap, return_softmax, gen_, descale_q_g, descale_k_g, descale_v_g, descale_o_g
+        )
+
+        # Copy output back to the main tensor
+        out[:, start_h:end_h, :] = out_g
+        softmax_lse_list.append(softmax_lse_g)
+
+    # Concatenate softmax_lse across heads
+    softmax_lse = torch.cat(softmax_lse_list, dim=0)  # varlen lse is [heads, total_seqlen]
+
+    return out, softmax_lse, None, rng_state
+
 def varlen_bwd(
     dout: torch.Tensor,
     q: torch.Tensor,
diff --git a/flash_attn/flash_attn_triton_amd/l2_cache_aware.py b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py
new file mode 100644
index 00000000000..7dd851b41e4
--- /dev/null
+++ b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py
@@ -0,0 +1,276 @@
+"""
+L2 Cache-Aware Head Grouping for Flash Attention
+
+This module provides functionality to optimize flash attention by processing
+heads in groups that fit in the L2 cache. This is particularly important for
+consumer AMD GPUs like gfx1100 (RX 7900 XTX) where the L2 cache is smaller
+than datacenter GPUs.
+
+The key insight is that for large sequence lengths, the K and V tensors for
+all heads may exceed L2 cache capacity, causing cache thrashing. By processing
+heads in groups that fit in L2, we can achieve up to 2x speedup.
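+
+The group size falls out of the cache budget (this is the calculation that
+calculate_optimal_head_group_size implements):
+
+    group_size = floor(L2_size * utilization / (seqlen_k * head_dim * elem_size * 2))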
+
+Example: gfx1100 with 96MB L2, 40 heads, seqlen=17160, head_dim=128
+- K,V for all 40 heads = 352 MB (exceeds 96 MB L2)
+- K,V for 10 heads = 88 MB (fits in 96 MB L2)
+- Processing 10 heads at a time gives 1.95x speedup
+"""
+
+import os
+import functools
+from typing import Optional, Tuple, Dict
+import torch
+
+# L2 cache sizes for AMD GPUs in bytes
+# Source: AMD documentation and hardware specs
+AMD_L2_CACHE_SIZES: Dict[str, int] = {
+    # RDNA3 workstation
+    "gfx1100": 96 * 1024 * 1024,  # RX 7900 XTX/XT - 96 MB
+}
+
+# Environment variable to override L2 cache size (in MB)
+L2_CACHE_OVERRIDE_ENV = "FLASH_ATTN_L2_CACHE_MB"
+# Environment variable to disable head grouping
+DISABLE_HEAD_GROUPING_ENV = "FLASH_ATTN_DISABLE_HEAD_GROUPING"
+
+# Cached L2 size per device
+_l2_cache_size_cache: Dict[int, int] = {}
+
+
+@functools.lru_cache(maxsize=None)
+def get_gcn_arch_name(device_index: int = 0) -> str:
+    """Get the GCN architecture name for an AMD GPU."""
+    try:
+        props = torch.cuda.get_device_properties(device_index)
+        if hasattr(props, 'gcnArchName'):
+            return props.gcnArchName
+        # Fallback: try to get from name
+        name = props.name.lower()
+        if 'gfx' in name:
+            # Extract gfxXXXX from name
+            import re
+            match = re.search(r'gfx\d+', name)
+            if match:
+                return match.group()
+    except Exception:
+        pass
+    return "unknown"
+
+
+def get_l2_cache_size(device_index: int = 0) -> int:
+    """
+    Get L2 cache size for the specified GPU device.
+
+    Returns:
+        L2 cache size in bytes
+    """
+    global _l2_cache_size_cache
+
+    if device_index in _l2_cache_size_cache:
+        return _l2_cache_size_cache[device_index]
+
+    # Check for environment override
+    if L2_CACHE_OVERRIDE_ENV in os.environ:
+        try:
+            size_mb = int(os.environ[L2_CACHE_OVERRIDE_ENV])
+            size_bytes = size_mb * 1024 * 1024
+            _l2_cache_size_cache[device_index] = size_bytes
+            return size_bytes
+        except ValueError:
+            pass
+
+    # Get architecture and look up cache size
+    arch = get_gcn_arch_name(device_index)
+
+    # Check exact match first
+    if arch in AMD_L2_CACHE_SIZES:
+        size = AMD_L2_CACHE_SIZES[arch]
+        _l2_cache_size_cache[device_index] = size
+        return size
+
+    # Check prefix match (e.g., 'gfx1100:xnack-' matches 'gfx1100')
+    for known_arch, size in AMD_L2_CACHE_SIZES.items():
+        if arch.startswith(known_arch):
+            _l2_cache_size_cache[device_index] = size
+            return size
+
+    # Default: assume 96 MB (conservative for RDNA3)
+    default_size = 96 * 1024 * 1024
+    _l2_cache_size_cache[device_index] = default_size
+    return default_size
+
+
+def calculate_optimal_head_group_size(
+    seqlen_k: int,
+    head_dim: int,
+    dtype: torch.dtype,
+    device_index: int = 0,
+    l2_utilization: float = 1.0 #use higher utilization by default to improve 1280x720 performance
+) -> int:
+    """
+    Calculate the optimal number of heads to process together to fit K,V in L2.
+
+    The calculation is:
+        K,V memory for N heads = N * seqlen_k * head_dim * dtype_size * 2 (for K and V)
+
+    We want: K,V memory <= L2_cache * utilization
+    So: N <= (L2_cache * utilization) / (seqlen_k * head_dim * dtype_size * 2)
+
+    Args:
+        seqlen_k: Sequence length of K/V
+        head_dim: Head dimension
+        dtype: Data type of tensors
+        device_index: GPU device index
+        l2_utilization: Fraction of L2 to target (default 1.0; lower it to leave room for Q)
+
+    Returns:
+        Optimal number of heads to process together (minimum 1)
+    """
+    l2_size = get_l2_cache_size(device_index)
+
+    # Get element size in bytes
+    if dtype in (torch.float16, torch.bfloat16):
+        elem_size = 2
+    elif dtype == torch.float32:
+        elem_size = 4
+    elif hasattr(torch, 'float8_e4m3fnuz') and dtype == torch.float8_e4m3fnuz:
+        elem_size = 1
+    elif hasattr(torch, 'float8_e5m2fnuz') and dtype == torch.float8_e5m2fnuz:
+        elem_size = 1
+    elif hasattr(torch, 'float8_e4m3fn') and dtype == torch.float8_e4m3fn:
+        elem_size = 1
+    elif hasattr(torch, 'float8_e5m2') and dtype == torch.float8_e5m2:
+        elem_size = 1
+    elif 'float8' in str(dtype).lower():
+        elem_size = 1
+    else:
+        elem_size = 2  # Default to fp16
+
+    # Memory for K and V per head
+    kv_per_head = seqlen_k * head_dim * elem_size * 2  # *2 for K and V
+
+    # Target L2 usage (leave some room for Q and other data)
+    target_l2 = int(l2_size * l2_utilization)
+
+    # Calculate number of heads that fit
+    if kv_per_head == 0:
+        return 1
+
+    head_group_size = max(1, target_l2 // kv_per_head)
+
+    return head_group_size
+
+
+def is_head_grouping_beneficial(
+    nheads: int,
+    seqlen_k: int,
+    head_dim: int,
+    dtype: torch.dtype,
+    device_index: int = 0,
+    threshold_ratio: float = 1.5
+) -> Tuple[bool, int]:
+    """
+    Determine if head grouping would be beneficial and return optimal group size.
+
+    Head grouping is beneficial when:
+    1. Total K,V memory exceeds L2 cache size by a significant margin
+    2. Processing in groups allows K,V to fit in L2
+    3. 
The overhead of multiple kernel launches is worth the cache benefit + + Args: + nheads: Number of attention heads + seqlen_k: Sequence length of K/V + head_dim: Head dimension + dtype: Data type + device_index: GPU device index + threshold_ratio: K,V must exceed L2 by this ratio to enable grouping + + Returns: + (should_group, group_size): Whether to group and the optimal group size + """ + # Check if disabled via environment + if os.environ.get(DISABLE_HEAD_GROUPING_ENV, "0") == "1": + return False, nheads + + l2_size = get_l2_cache_size(device_index) + + # Get element size + if dtype in (torch.float16, torch.bfloat16): + elem_size = 2 + elif dtype == torch.float32: + elem_size = 4 + elif hasattr(torch, 'float8_e4m3fnuz') and dtype == torch.float8_e4m3fnuz: + elem_size = 1 + elif hasattr(torch, 'float8_e5m2fnuz') and dtype == torch.float8_e5m2fnuz: + elem_size = 1 + elif hasattr(torch, 'float8_e4m3fn') and dtype == torch.float8_e4m3fn: + elem_size = 1 + elif hasattr(torch, 'float8_e5m2') and dtype == torch.float8_e5m2: + elem_size = 1 + elif 'float8' in str(dtype).lower(): + elem_size = 1 + else: + elem_size = 2 + + # Total K,V memory for all heads + total_kv = nheads * seqlen_k * head_dim * elem_size * 2 + + # Only group if K,V significantly exceeds L2 + if total_kv < l2_size * threshold_ratio: + return False, nheads + + # Calculate optimal group size + group_size = calculate_optimal_head_group_size( + seqlen_k, head_dim, dtype, device_index + ) + + # Only group if we'd have at least 2 groups + # (otherwise grouping adds overhead with no benefit) + if group_size >= nheads: + return False, nheads + + # Minimum group size to avoid excessive kernel launches + min_group_size = max(1, nheads // 16) # At most 16 groups + group_size = max(group_size, min_group_size) + + return True, min(group_size, nheads) + + +def print_head_grouping_info( + nheads: int, + seqlen_k: int, + head_dim: int, + dtype: torch.dtype, + device_index: int = 0 +): + """Print diagnostic information about head grouping.""" + l2_size = get_l2_cache_size(device_index) + arch = get_gcn_arch_name(device_index) + + if dtype in (torch.float16, torch.bfloat16): + elem_size = 2 + elif dtype == torch.float32: + elem_size = 4 + elif 'float8' in str(dtype).lower(): + elem_size = 1 + else: + elem_size = 2 + + total_kv = nheads * seqlen_k * head_dim * elem_size * 2 + should_group, group_size = is_head_grouping_beneficial( + nheads, seqlen_k, head_dim, dtype, device_index + ) + + print(f"\n=== L2 Cache-Aware Head Grouping ===") + print(f"GPU: {arch}") + print(f"L2 Cache: {l2_size / (1024*1024):.1f} MB") + print(f"Heads: {nheads}, SeqLen: {seqlen_k}, HeadDim: {head_dim}") + print(f"Total K,V Memory: {total_kv / (1024*1024):.1f} MB") + print(f"L2 Ratio: {total_kv / l2_size:.2f}x") + print(f"Should Group: {should_group}") + if should_group: + kv_per_group = group_size * seqlen_k * head_dim * elem_size * 2 + num_groups = (nheads + group_size - 1) // group_size + print(f"Group Size: {group_size} heads ({num_groups} groups)") + print(f"K,V per Group: {kv_per_group / (1024*1024):.1f} MB") + print("=" * 40 + "\n") From e935f3b5e7b7f36a8b8a017ddae807e03fe561a9 Mon Sep 17 00:00:00 2001 From: Tianwei Yang Date: Tue, 6 Jan 2026 06:43:43 +0000 Subject: [PATCH 5/6] [Navi]renaming L2 cache to Infinity Cache (LLC) to avoid confusion --- .../flash_attn_triton_amd/l2_cache_aware.py | 203 +++++++++--------- 1 file changed, 99 insertions(+), 104 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/l2_cache_aware.py 
b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py index 7dd851b41e4..f506e330aae 100644 --- a/flash_attn/flash_attn_triton_amd/l2_cache_aware.py +++ b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py @@ -1,18 +1,19 @@ """ -L2 Cache-Aware Head Grouping for Flash Attention +Infinity Cache (LLC) Aware Head Grouping for Flash Attention This module provides functionality to optimize flash attention by processing -heads in groups that fit in the L2 cache. This is particularly important for -consumer AMD GPUs like gfx1100 (RX 7900 XTX) where the L2 cache is smaller -than datacenter GPUs. +heads in groups that fit in the Last Level Cache (LLC / Infinity Cache). -The key insight is that for large sequence lengths, the K and V tensors for -all heads may exceed L2 cache capacity, causing cache thrashing. By processing -heads in groups that fit in L2, we can achieve up to 2x speedup. +AMD RDNA3 cache hierarchy: +- L2 Cache: 6 MB (per-die, fast) +- Infinity Cache (L3/LLC): 96 MB (acts as memory-side cache) -Example: gfx1100 with 96MB L2, 40 heads, seqlen=17160, head_dim=128 -- K,V for all 40 heads = 352 MB (exceeds 96 MB L2) -- K,V for 10 heads = 88 MB (fits in 96 MB L2) +For large sequence lengths, we want K,V to fit in the 96 MB Infinity Cache. +By processing heads in groups that fit, we achieve up to 2x speedup. + +Example: gfx1100 with 96MB Infinity Cache, 40 heads, seqlen=17160, head_dim=128 +- K,V for all 40 heads = 352 MB (exceeds 96 MB LLC) +- K,V for 10 heads = 88 MB (fits in 96 MB LLC) - Processing 10 heads at a time gives 1.95x speedup """ @@ -21,20 +22,28 @@ from typing import Optional, Tuple, Dict import torch -# L2 cache sizes for AMD GPUs in bytes -# Source: AMD documentation and hardware specs -AMD_L2_CACHE_SIZES: Dict[str, int] = { - # RDNA3 workstaion - "gfx1100": 96 * 1024 * 1024, # RX 7900 XTX/XT - 96 MB +# Infinity Cache (LLC) sizes for AMD GPUs in bytes +# Note: This is the L3/Infinity Cache, NOT the L2 cache +# RDNA3: L2=6MB, Infinity Cache (LLC)=96MB +AMD_LLC_CACHE_SIZES: Dict[str, int] = { + # RDNA3 consumer + "gfx1100": 96 * 1024 * 1024, # RX 7900 XTX/XT - 96 MB Infinity Cache + "gfx1101": 64 * 1024 * 1024, # RX 7800 XT - 64 MB Infinity Cache + "gfx1102": 32 * 1024 * 1024, # RX 7600 - 32 MB Infinity Cache } -# Environment variable to override L2 cache size (in MB) -L2_CACHE_OVERRIDE_ENV = "FLASH_ATTN_L2_CACHE_MB" +# Legacy alias for backwards compatibility +AMD_L2_CACHE_SIZES = AMD_LLC_CACHE_SIZES + +# Environment variable to override LLC cache size (in MB) +LLC_CACHE_OVERRIDE_ENV = "FLASH_ATTN_LLC_CACHE_MB" +L2_CACHE_OVERRIDE_ENV = "FLASH_ATTN_L2_CACHE_MB" # Legacy alias + # Environment variable to disable head grouping DISABLE_HEAD_GROUPING_ENV = "FLASH_ATTN_DISABLE_HEAD_GROUPING" -# Cached L2 size per device -_l2_cache_size_cache: Dict[int, int] = {} +# Cached LLC size per device +_llc_cache_size_cache: Dict[int, int] = {} @functools.lru_cache(maxsize=None) @@ -57,90 +66,100 @@ def get_gcn_arch_name(device_index: int = 0) -> str: return "unknown" -def get_l2_cache_size(device_index: int = 0) -> int: +def get_num_cus(device_index: int = 0) -> int: """ - Get L2 cache size for the specified GPU device. + Get the number of Compute Units for an AMD GPU. + + Note: PyTorch's multi_processor_count may be incorrect for some AMD GPUs. + We use known values for common architectures. 
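+    (For example, gfx1100 / RX 7900 XTX has 96 CUs, as listed in known_cus below.)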
+ """ + arch = get_gcn_arch_name(device_index) + + # Known CU counts for common GPUs + known_cus = { + "gfx1100": 96, # RX 7900 XTX + "gfx1101": 60, # RX 7800 XT + "gfx1102": 32, # RX 7600 + } + + if arch in known_cus: + return known_cus[arch] + + # Fallback to PyTorch (may be incorrect) + try: + props = torch.cuda.get_device_properties(device_index) + return props.multi_processor_count + except Exception: + return 96 # Default + + +def get_llc_cache_size(device_index: int = 0) -> int: + """ + Get Infinity Cache (LLC) size for the specified GPU device. + + For RDNA3, this is the 96 MB Infinity Cache, not the 6 MB L2. Returns: - L2 cache size in bytes + LLC cache size in bytes """ - global _l2_cache_size_cache + global _llc_cache_size_cache - if device_index in _l2_cache_size_cache: - return _l2_cache_size_cache[device_index] + if device_index in _llc_cache_size_cache: + return _llc_cache_size_cache[device_index] - # Check for environment override - if L2_CACHE_OVERRIDE_ENV in os.environ: - try: - size_mb = int(os.environ[L2_CACHE_OVERRIDE_ENV]) - size_bytes = size_mb * 1024 * 1024 - _l2_cache_size_cache[device_index] = size_bytes - return size_bytes - except ValueError: - pass + # Check for environment override (new name first, then legacy) + for env_var in [LLC_CACHE_OVERRIDE_ENV, L2_CACHE_OVERRIDE_ENV]: + if env_var in os.environ: + try: + size_mb = int(os.environ[env_var]) + size_bytes = size_mb * 1024 * 1024 + _llc_cache_size_cache[device_index] = size_bytes + return size_bytes + except ValueError: + pass # Get architecture and look up cache size arch = get_gcn_arch_name(device_index) # Check exact match first - if arch in AMD_L2_CACHE_SIZES: - size = AMD_L2_CACHE_SIZES[arch] - _l2_cache_size_cache[device_index] = size + if arch in AMD_LLC_CACHE_SIZES: + size = AMD_LLC_CACHE_SIZES[arch] + _llc_cache_size_cache[device_index] = size return size # Check prefix match (e.g., gfx1100 matches gfx1100) - for known_arch, size in AMD_L2_CACHE_SIZES.items(): + for known_arch, size in AMD_LLC_CACHE_SIZES.items(): if arch.startswith(known_arch): - _l2_cache_size_cache[device_index] = size + _llc_cache_size_cache[device_index] = size return size # Default: assume 96 MB (conservative for RDNA3) default_size = 96 * 1024 * 1024 - _l2_cache_size_cache[device_index] = default_size + _llc_cache_size_cache[device_index] = default_size return default_size +# Legacy alias +get_l2_cache_size = get_llc_cache_size + + def calculate_optimal_head_group_size( seqlen_k: int, head_dim: int, dtype: torch.dtype, device_index: int = 0, - l2_utilization: float = 1.0 #use higher utilization by default to improve 1280x720 performance + llc_utilization: float = 1.0 # Use higher utilization by default ) -> int: """ - Calculate the optimal number of heads to process together to fit K,V in L2. - - The calculation is: - K,V memory for N heads = N * seqlen_k * head_dim * dtype_size * 2 (for K and V) - - We want: K,V memory <= L2_cache * utilization - So: N <= (L2_cache * utilization) / (seqlen_k * head_dim * dtype_size * 2) - - Args: - seqlen_k: Sequence length of K/V - head_dim: Head dimension - dtype: Data type of tensors - device_index: GPU device index - l2_utilization: Fraction of L2 to target (default 0.9 to leave room for Q) - - Returns: - Optimal number of heads to process together (minimum 1) + Calculate the optimal number of heads to process together to fit K,V in LLC. 
""" - l2_size = get_l2_cache_size(device_index) + llc_size = get_llc_cache_size(device_index) # Get element size in bytes if dtype in (torch.float16, torch.bfloat16): elem_size = 2 elif dtype == torch.float32: elem_size = 4 - elif hasattr(torch, 'float8_e4m3fnuz') and dtype == torch.float8_e4m3fnuz: - elem_size = 1 - elif hasattr(torch, 'float8_e5m2fnuz') and dtype == torch.float8_e5m2fnuz: - elem_size = 1 - elif hasattr(torch, 'float8_e4m3fn') and dtype == torch.float8_e4m3fn: - elem_size = 1 - elif hasattr(torch, 'float8_e5m2') and dtype == torch.float8_e5m2: - elem_size = 1 elif 'float8' in str(dtype).lower(): elem_size = 1 else: @@ -149,14 +168,14 @@ def calculate_optimal_head_group_size( # Memory for K and V per head kv_per_head = seqlen_k * head_dim * elem_size * 2 # *2 for K and V - # Target L2 usage (leave some room for Q and other data) - target_l2 = int(l2_size * l2_utilization) + # Target LLC usage + target_llc = int(llc_size * llc_utilization) # Calculate number of heads that fit if kv_per_head == 0: return 1 - head_group_size = max(1, target_l2 // kv_per_head) + head_group_size = max(1, target_llc // kv_per_head) return head_group_size @@ -171,42 +190,18 @@ def is_head_grouping_beneficial( ) -> Tuple[bool, int]: """ Determine if head grouping would be beneficial and return optimal group size. - - Head grouping is beneficial when: - 1. Total K,V memory exceeds L2 cache size by a significant margin - 2. Processing in groups allows K,V to fit in L2 - 3. The overhead of multiple kernel launches is worth the cache benefit - - Args: - nheads: Number of attention heads - seqlen_k: Sequence length of K/V - head_dim: Head dimension - dtype: Data type - device_index: GPU device index - threshold_ratio: K,V must exceed L2 by this ratio to enable grouping - - Returns: - (should_group, group_size): Whether to group and the optimal group size """ # Check if disabled via environment if os.environ.get(DISABLE_HEAD_GROUPING_ENV, "0") == "1": return False, nheads - l2_size = get_l2_cache_size(device_index) + llc_size = get_llc_cache_size(device_index) # Get element size if dtype in (torch.float16, torch.bfloat16): elem_size = 2 elif dtype == torch.float32: elem_size = 4 - elif hasattr(torch, 'float8_e4m3fnuz') and dtype == torch.float8_e4m3fnuz: - elem_size = 1 - elif hasattr(torch, 'float8_e5m2fnuz') and dtype == torch.float8_e5m2fnuz: - elem_size = 1 - elif hasattr(torch, 'float8_e4m3fn') and dtype == torch.float8_e4m3fn: - elem_size = 1 - elif hasattr(torch, 'float8_e5m2') and dtype == torch.float8_e5m2: - elem_size = 1 elif 'float8' in str(dtype).lower(): elem_size = 1 else: @@ -215,8 +210,8 @@ def is_head_grouping_beneficial( # Total K,V memory for all heads total_kv = nheads * seqlen_k * head_dim * elem_size * 2 - # Only group if K,V significantly exceeds L2 - if total_kv < l2_size * threshold_ratio: + # Only group if K,V significantly exceeds LLC + if total_kv < llc_size * threshold_ratio: return False, nheads # Calculate optimal group size @@ -225,7 +220,6 @@ def is_head_grouping_beneficial( ) # Only group if we'd have at least 2 groups - # (otherwise grouping adds overhead with no benefit) if group_size >= nheads: return False, nheads @@ -244,8 +238,9 @@ def print_head_grouping_info( device_index: int = 0 ): """Print diagnostic information about head grouping.""" - l2_size = get_l2_cache_size(device_index) + llc_size = get_llc_cache_size(device_index) arch = get_gcn_arch_name(device_index) + num_cus = get_num_cus(device_index) if dtype in (torch.float16, torch.bfloat16): elem_size = 2 
@@ -261,16 +256,16 @@ def print_head_grouping_info( nheads, seqlen_k, head_dim, dtype, device_index ) - print(f"\n=== L2 Cache-Aware Head Grouping ===") - print(f"GPU: {arch}") - print(f"L2 Cache: {l2_size / (1024*1024):.1f} MB") + print(f"\n=== Infinity Cache (LLC) Aware Head Grouping ===") + print(f"GPU: {arch} ({num_cus} CUs)") + print(f"Infinity Cache (LLC): {llc_size / (1024*1024):.1f} MB") print(f"Heads: {nheads}, SeqLen: {seqlen_k}, HeadDim: {head_dim}") print(f"Total K,V Memory: {total_kv / (1024*1024):.1f} MB") - print(f"L2 Ratio: {total_kv / l2_size:.2f}x") + print(f"LLC Ratio: {total_kv / llc_size:.2f}x") print(f"Should Group: {should_group}") if should_group: kv_per_group = group_size * seqlen_k * head_dim * elem_size * 2 num_groups = (nheads + group_size - 1) // group_size print(f"Group Size: {group_size} heads ({num_groups} groups)") print(f"K,V per Group: {kv_per_group / (1024*1024):.1f} MB") - print("=" * 40 + "\n") + print("=" * 48 + "\n") From 92cc73ac2bb27a044d83f99f6fdf30e7f7e036a1 Mon Sep 17 00:00:00 2001 From: Tianwei Yang Date: Fri, 9 Jan 2026 01:23:17 +0000 Subject: [PATCH 6/6] [ROCM]Optimized for gfx1100 (RDNA3) with LLC-aware head grouping for long seqlen --- .../flash_attn_triton_amd/fwd_prefill.py | 27 +++++++++---------- .../flash_attn_triton_amd/l2_cache_aware.py | 2 +- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 7e814be4118..b0b320321b6 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -186,21 +186,18 @@ def get_cdna_autotune_configs(): def get_rdna_autotune_configs(): return [ - # === Configs for head_dim=64 (optimal) === - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + # Best config from autotune on gfx1100: 32x16, warps=2, PRE_LOAD_V=True + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=1), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=1), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + # === Configs for head_dim=128 === + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4), + # === Fallback configs === triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - # === Configs for head_dim=128 (Wan2.2) - smaller blocks to reduce register pressure === - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 
'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - # === General fallback configs === - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] @@ -215,9 +212,9 @@ def get_autotune_configs(): raise ValueError("Unknown Device Type") else: return [ - # Use BLOCK_N=32 to avoid register spilling on gfx1100 + # Optimized for gfx1100 (RDNA3) with LLC-aware head grouping triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "waves_per_eu": 1, "PRE_LOAD_V": True}, + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": True}, num_stages=1, num_warps=4, ), diff --git a/flash_attn/flash_attn_triton_amd/l2_cache_aware.py b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py index f506e330aae..981f8ce5702 100644 --- a/flash_attn/flash_attn_triton_amd/l2_cache_aware.py +++ b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py @@ -148,7 +148,7 @@ def calculate_optimal_head_group_size( head_dim: int, dtype: torch.dtype, device_index: int = 0, - llc_utilization: float = 1.0 # Use higher utilization by default + llc_utilization: float = 1.5 # Use 150% of LLC - optimal for long sequences ) -> int: """ Calculate the optimal number of heads to process together to fit K,V in LLC.
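
Usage note: the sketch below is illustrative only and not part of the patches; it
assumes the patched package is importable, and it exercises the grouping heuristic
directly using only functions defined in l2_cache_aware.py above, with the shape
from the module docstring's example.

    import torch

    from flash_attn.flash_attn_triton_amd.l2_cache_aware import (
        get_llc_cache_size,
        is_head_grouping_beneficial,
        print_head_grouping_info,
    )

    # The module also honors these environment variables (set before calling):
    #   FLASH_ATTN_LLC_CACHE_MB=64           pretend the LLC is 64 MB
    #   FLASH_ATTN_DISABLE_HEAD_GROUPING=1   force the single-group path
    nheads, seqlen_k, head_dim = 40, 17160, 128  # shape from the docstring example
    should_group, group_size = is_head_grouping_beneficial(
        nheads, seqlen_k, head_dim, torch.float16, device_index=0
    )
    print(f"LLC={get_llc_cache_size(0) / (1024 * 1024):.0f} MB, "
          f"should_group={should_group}, group_size={group_size}")
    print_head_grouping_info(nheads, seqlen_k, head_dim, torch.float16, 0)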