diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
index ef8a9d5ff45..f748b0e6a30 100755
--- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py
+++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -11,6 +11,7 @@
     get_arch,
     is_fp8,
 )
+from .llc_cache_aware import is_head_grouping_beneficial
@@ -78,10 +79,15 @@ def get_fwd_prefill_configs(autotune: bool):
     elif arch.is_rdna:
         return [
             triton.Config(
-                {"BLOCK_M": 32, "BLOCK_N": 32, "PRE_LOAD_V": False},
+                {
+                    "BLOCK_M": 64,
+                    "BLOCK_N": 64,
+                    "waves_per_eu": 1,
+                    "PRE_LOAD_V": False,
+                },
                 num_stages=1,
                 num_warps=4,
-            ),
+            )
         ]
     else:
         return [
@@ -104,7 +110,7 @@ def get_fwd_prefill_configs(autotune: bool):
     NUM_WARPS_OPTIONS = [2, 4, 8]
     NUM_STAGES_OPTIONS = [1, 2]
     WAVES_PER_EU_OPTIONS = [4, 2, 1]
-    PRE_LOAD_V_OPTIONS = [False]
+    PRE_LOAD_V_OPTIONS = [False, True]
     for bm in BLOCK_M_OPTIONS:
         for bn in BLOCK_N_OPTIONS:
             for waves in WAVES_PER_EU_OPTIONS:
@@ -1265,7 +1271,7 @@ def attn_fwd(
     tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask)


-def attention_forward_prefill_triton_impl(
+def _attention_forward_prefill_triton_impl_core(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
@@ -1741,3 +1747,179 @@ def attention_forward_prefill_triton_impl(
         USE_SEQUSED=(seqused_q is not None or seqused_k is not None),
         FORCE_MASKING=force_masking,
     )
+
+def attention_forward_prefill_triton_impl(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    o: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    sd_mask: Optional[torch.Tensor],
+    sm_scale: float,
+    alibi_slopes: Optional[torch.Tensor],
+    causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    bias: Optional[torch.Tensor],
+    layout: Literal["bshd", "bhsd", "thd"],
+    # varlen
+    cu_seqlens_q: Optional[torch.Tensor],
+    cu_seqlens_k: Optional[torch.Tensor],
+    max_seqlens_q: int,
+    max_seqlens_k: int,
+    # dropout
+    dropout_p: float,
+    philox_seed: Optional[int],
+    philox_offset: Optional[int],
+    # misc
+    return_scores: bool,
+    use_exp2: bool,
+    # fp8
+    q_descale: Optional[torch.Tensor],
+    k_descale: Optional[torch.Tensor],
+    v_descale: Optional[torch.Tensor],
+    # seqused for FA v3
+    seqused_q: Optional[torch.Tensor] = None,
+    seqused_k: Optional[torch.Tensor] = None,
+    # rotary (optional)
+    rotary_cos: Optional[torch.Tensor] = None,
+    rotary_sin: Optional[torch.Tensor] = None,
+    rotary_interleaved: bool = False,
+    seqlens_rotary: Optional[torch.Tensor] = None,
+):
+    """
+    Wrapper for attention forward with LLC-aware head grouping optimization.
+
+    For long sequences on GPUs with large LLC (e.g., RDNA3 with 96 MB Infinity Cache),
+    processing heads in groups that fit K,V in cache can significantly improve performance.
+ """ + IS_VARLEN = layout == "thd" + + # Get head dimensions + if IS_VARLEN: + total_q, nheads_q, head_dim = q.shape + nheads_k = k.shape[1] + else: + batch, seqlen_q, nheads_q, head_dim = q.shape + nheads_k = k.shape[2] + + # Check if head grouping is beneficial + should_group, group_size = is_head_grouping_beneficial( + nheads_k, max_seqlens_k, head_dim, q.dtype, q.device.index or 0 + ) + + # Disable head grouping if return_scores is requested (need full attention matrix) + # or if sd_mask is provided + if return_scores or sd_mask is not None: + should_group = False + + if not should_group or group_size >= nheads_q: + # No grouping needed - call core implementation directly + return _attention_forward_prefill_triton_impl_core( + q, k, v, o, softmax_lse, sd_mask, + sm_scale, alibi_slopes, causal, + window_size_left, window_size_right, bias, layout, + cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k, + dropout_p, philox_seed, philox_offset, + return_scores, use_exp2, + q_descale, k_descale, v_descale, + seqused_q, seqused_k, + rotary_cos, rotary_sin, rotary_interleaved, seqlens_rotary, + ) + + # Head grouping path + if DEBUG: + print(f"[LLC Head Grouping fwd_prefill] Processing {nheads_q} heads in groups of {group_size}") + + gqa_ratio = nheads_q // nheads_k + n_groups = (nheads_q + group_size - 1) // group_size + + # Calculate K,V heads per group (for GQA) + group_size_k = (group_size + gqa_ratio - 1) // gqa_ratio + + # Pre-allocate K,V buffers to avoid repeated allocations in loop + # This reuses memory across iterations instead of calling .contiguous() each time + if IS_VARLEN: + # thd layout: (total_tokens, nheads_k_group, head_dim) + k_buffer = torch.empty((total_q, group_size_k, head_dim), device=k.device, dtype=k.dtype) + v_buffer = torch.empty((total_q, group_size_k, head_dim), device=v.device, dtype=v.dtype) + else: + # bshd layout: (batch, seqlen_k, nheads_k_group, head_dim) + seqlen_k = k.shape[1] + k_buffer = torch.empty((batch, seqlen_k, group_size_k, head_dim), device=k.device, dtype=k.dtype) + v_buffer = torch.empty((batch, seqlen_k, group_size_k, head_dim), device=v.device, dtype=v.dtype) + + softmax_lse_list = [] + + for g in range(n_groups): + start_h = g * group_size + end_h = min((g + 1) * group_size, nheads_q) + actual_heads = end_h - start_h + + # For GQA, compute corresponding K,V head range + start_h_k = start_h // gqa_ratio + end_h_k = (end_h + gqa_ratio - 1) // gqa_ratio + actual_heads_k = end_h_k - start_h_k + + if IS_VARLEN: + # thd layout: (total_tokens, nheads, head_dim) + q_group = q[:, start_h:end_h, :] # strided view + o_group = o[:, start_h:end_h, :] # strided view, write directly + + # Copy K,V into pre-allocated buffers + k_group = k_buffer[:, :actual_heads_k, :] + v_group = v_buffer[:, :actual_heads_k, :] + k_group.copy_(k[:, start_h_k:end_h_k, :]) + v_group.copy_(v[:, start_h_k:end_h_k, :]) + + # softmax_lse for varlen: (Hq, Total_Q) + softmax_lse_group = torch.zeros( + (actual_heads, total_q), device=q.device, dtype=torch.float32 + ) + else: + # bshd layout: (batch, seqlen, nheads, head_dim) + q_group = q[:, :, start_h:end_h, :] # strided view + o_group = o[:, :, start_h:end_h, :] # strided view, write directly + + # Copy K,V into pre-allocated buffers + k_group = k_buffer[:, :, :actual_heads_k, :] + v_group = v_buffer[:, :, :actual_heads_k, :] + k_group.copy_(k[:, :, start_h_k:end_h_k, :]) + v_group.copy_(v[:, :, start_h_k:end_h_k, :]) + + # softmax_lse for bshd: (B, Hq, Sq) + softmax_lse_group = torch.zeros( + (batch, actual_heads, 
diff --git a/flash_attn/flash_attn_triton_amd/llc_cache_aware.py b/flash_attn/flash_attn_triton_amd/llc_cache_aware.py
new file mode 100644
index 00000000000..d1704843c78
--- /dev/null
+++ b/flash_attn/flash_attn_triton_amd/llc_cache_aware.py
@@ -0,0 +1,237 @@
+"""
+Infinity Cache (LLC) Aware Head Grouping for Flash Attention
+
+This module provides functionality to optimize flash attention by processing
+heads in groups that fit in the Last Level Cache (LLC / Infinity Cache).
+
+AMD RDNA3 cache hierarchy:
+- L2 Cache: 6 MB (per-die, fast)
+- Infinity Cache (L3/LLC): 96 MB (acts as memory-side cache)
+
+For large sequence lengths, we want K,V to fit in the 96 MB Infinity Cache.
+By processing heads in groups that fit, we achieve up to 2x speedup.
+
+Example: gfx1100 with 96MB Infinity Cache, 40 heads, seqlen=17160, head_dim=128
+- K,V for all 40 heads = 352 MB (exceeds 96 MB LLC)
+- K,V for 10 heads = 88 MB (fits in 96 MB LLC)
+- Processing 10 heads at a time gives 1.95x speedup
+"""
+
+import os
+from typing import Tuple, Dict
+
+import torch
+
+from .utils import get_arch
+
+# Infinity Cache (LLC) sizes for AMD GPUs in bytes
+# Note: This is the L3/Infinity Cache, NOT the L2 cache
+# RDNA3: L2=6MB, Infinity Cache (LLC)=96MB
+AMD_LLC_CACHE_SIZES: Dict[str, int] = {
+    # RDNA2
+    "gfx1030": 128 * 1024 * 1024,  # RX 6900 XT - 128 MB Infinity Cache
+    # RDNA3 consumer
+    "gfx1100": 96 * 1024 * 1024,   # RX 7900 XTX - 96 MB Infinity Cache
+    "gfx1101": 64 * 1024 * 1024,   # RX 7800 XT - 64 MB Infinity Cache
+    "gfx1102": 32 * 1024 * 1024,   # RX 7600 - 32 MB Infinity Cache
+    # RDNA4
+    "gfx1200": 32 * 1024 * 1024,   # RX 9060/XT - 32 MB Infinity Cache
+    "gfx1201": 64 * 1024 * 1024,   # RX 9070/XT - 64 MB Infinity Cache
+}
+
+# Legacy alias for backwards compatibility
+AMD_L2_CACHE_SIZES = AMD_LLC_CACHE_SIZES
+
+# Environment variable to override LLC cache size (in MB)
+LLC_CACHE_OVERRIDE_ENV = "FLASH_ATTN_LLC_CACHE_MB"
+L2_CACHE_OVERRIDE_ENV = "FLASH_ATTN_L2_CACHE_MB"  # Legacy alias
+
+# Environment variable to disable head grouping
+DISABLE_HEAD_GROUPING_ENV = "FLASH_ATTN_DISABLE_HEAD_GROUPING"
+
+# Cached LLC size per device
+_llc_cache_size_cache: Dict[int, int] = {}
+
+
+def get_llc_cache_size(device_index: int = 0) -> int:
+    """
+    Get Infinity Cache (LLC) size for the specified GPU device.
+
+    For RDNA3, this is the 96 MB Infinity Cache, not the 6 MB L2.
+
+    Returns:
+        LLC cache size in bytes
+    """
+    global _llc_cache_size_cache
+
+    if device_index in _llc_cache_size_cache:
+        return _llc_cache_size_cache[device_index]
+
+    # Check for environment override (new name first, then legacy)
+    for env_var in [LLC_CACHE_OVERRIDE_ENV, L2_CACHE_OVERRIDE_ENV]:
+        if env_var in os.environ:
+            try:
+                size_mb = int(os.environ[env_var])
+                size_bytes = size_mb * 1024 * 1024
+                _llc_cache_size_cache[device_index] = size_bytes
+                return size_bytes
+            except ValueError:
+                pass
+
+    # Get architecture using utils.get_arch()
+    arch = get_arch().name
+
+    # Check exact match first
+    if arch in AMD_LLC_CACHE_SIZES:
+        size = AMD_LLC_CACHE_SIZES[arch]
+        _llc_cache_size_cache[device_index] = size
+        return size
+
+    # Check prefix match (e.g., an arch string reported with feature suffixes
+    # still matches its base gfx name)
+    for known_arch, size in AMD_LLC_CACHE_SIZES.items():
+        if arch.startswith(known_arch):
+            _llc_cache_size_cache[device_index] = size
+            return size
+
+    # Default: fall back to 96 MB (the RDNA3 flagship size) for unknown archs
+    default_size = 96 * 1024 * 1024
+    _llc_cache_size_cache[device_index] = default_size
+    return default_size
+
+
+# Legacy alias
+get_l2_cache_size = get_llc_cache_size
+
+
+def calculate_optimal_head_group_size(
+    seqlen_k: int,
+    head_dim: int,
+    dtype: torch.dtype,
+    device_index: int = 0,
+    llc_utilization: float = 1.0  # Use 100% of LLC - optimal for long sequences
+) -> int:
+    """
+    Calculate the optimal number of heads to process together to fit K,V in LLC.
+    """
+    llc_size = get_llc_cache_size(device_index)
+
+    # Get element size in bytes
+    if dtype in (torch.float16, torch.bfloat16):
+        elem_size = 2
+    elif dtype == torch.float32:
+        elem_size = 4
+    elif 'float8' in str(dtype).lower():
+        elem_size = 1
+    else:
+        elem_size = 2  # Default to fp16
+
+    # Memory for K and V per head
+    kv_per_head = seqlen_k * head_dim * elem_size * 2  # *2 for K and V
+
+    # Target LLC usage
+    target_llc = int(llc_size * llc_utilization)
+
+    # Calculate number of heads that fit
+    if kv_per_head == 0:
+        return 1
+
+    head_group_size = max(1, target_llc // kv_per_head)
+
+    return head_group_size
+
+
+def is_head_grouping_beneficial(
+    nheads: int,
+    seqlen_k: int,
+    head_dim: int,
+    dtype: torch.dtype,
+    device_index: int = 0,
+    threshold_ratio: float = 1.5
+) -> Tuple[bool, int]:
+    """
+    Determine if head grouping would be beneficial and return optimal group size.
+
+    Head grouping is only beneficial for RDNA GPUs with Infinity Cache (LLC).
+    CDNA GPUs (MI250, MI300, etc.) have different cache architectures.
+ """ + # Check if disabled via environment + if os.environ.get(DISABLE_HEAD_GROUPING_ENV, "0") == "1": + return False, nheads + + # Only apply head grouping to RDNA GPUs (which have Infinity Cache) + arch = get_arch() + if not arch.is_rdna: + return False, nheads + + llc_size = get_llc_cache_size(device_index) + + # Get element size + if dtype in (torch.float16, torch.bfloat16): + elem_size = 2 + elif dtype == torch.float32: + elem_size = 4 + elif 'float8' in str(dtype).lower(): + elem_size = 1 + else: + elem_size = 2 + + # Total K,V memory for all heads + total_kv = nheads * seqlen_k * head_dim * elem_size * 2 + + # Only group if K,V significantly exceeds LLC + if total_kv < llc_size * threshold_ratio: + return False, nheads + + # Calculate optimal group size + group_size = calculate_optimal_head_group_size( + seqlen_k, head_dim, dtype, device_index + ) + + # Only group if we'd have at least 2 groups + if group_size >= nheads: + return False, nheads + + # Minimum group size to avoid excessive kernel launches + min_group_size = max(1, nheads // 16) # At most 16 groups + group_size = max(group_size, min_group_size) + + return True, min(group_size, nheads) + + +def print_head_grouping_info( + nheads: int, + seqlen_k: int, + head_dim: int, + dtype: torch.dtype, + device_index: int = 0 +): + """Print diagnostic information about head grouping.""" + llc_size = get_llc_cache_size(device_index) + arch = get_arch() + + if dtype in (torch.float16, torch.bfloat16): + elem_size = 2 + elif dtype == torch.float32: + elem_size = 4 + elif 'float8' in str(dtype).lower(): + elem_size = 1 + else: + elem_size = 2 + + total_kv = nheads * seqlen_k * head_dim * elem_size * 2 + should_group, group_size = is_head_grouping_beneficial( + nheads, seqlen_k, head_dim, dtype, device_index + ) + + print(f"\n=== Infinity Cache (LLC) Aware Head Grouping ===") + print(f"GPU: {arch.name}") + print(f"Infinity Cache (LLC): {llc_size / (1024*1024):.1f} MB") + print(f"Heads: {nheads}, SeqLen: {seqlen_k}, HeadDim: {head_dim}") + print(f"Total K,V Memory: {total_kv / (1024*1024):.1f} MB") + print(f"LLC Ratio: {total_kv / llc_size:.2f}x") + print(f"Should Group: {should_group}") + if should_group: + kv_per_group = group_size * seqlen_k * head_dim * elem_size * 2 + num_groups = (nheads + group_size - 1) // group_size + print(f"Group Size: {group_size} heads ({num_groups} groups)") + print(f"K,V per Group: {kv_per_group / (1024*1024):.1f} MB") + print("=" * 48 + "\n")