diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index 3f2d92c22d6..5c16cf4c552 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -2,7 +2,7 @@ import triton import triton.language as tl from typing import Literal, Optional, Union -from .utils import AUTOTUNE, DEBUG, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna +from .utils import AUTOTUNE, DEBUG, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna, is_rdna def get_cdna_autotune_configs(): return [ @@ -23,6 +23,26 @@ def get_cdna_autotune_configs(): num_warps=4), ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] +def get_rdna_autotune_configs(): + return [ + # Most aggressive - 128x128 (best for large sequences) + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + # Large blocks + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + # Medium blocks + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 
32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), + # Fall-back config. + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + def get_autotune_configs(): if AUTOTUNE: if is_cdna(): @@ -30,6 +50,11 @@ def get_autotune_configs(): fwd_auto_tune_configs, fwd_autotune_keys= autotune_configs, autotune_keys reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) + elif is_rdna(): + autotune_configs, autotune_keys = get_rdna_autotune_configs() + fwd_auto_tune_configs, fwd_autotune_keys= autotune_configs, autotune_keys + reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys + return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) else: raise ValueError("Unknown Device Type") else: diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 6f69cd02813..b0b320321b6 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -186,24 +186,22 @@ def get_cdna_autotune_configs(): def get_rdna_autotune_configs(): return [ - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 
'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - # Fall-back config. - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), + # Best config from autotune on gfx1100: 32x16, warps=2, PRE_LOAD_V=True + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=1), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=1), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + # === Configs for head_dim=128 === + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4), + # === Fallback configs === + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] + def get_autotune_configs(): if AUTOTUNE: if is_rdna(): @@ -214,8 +212,9 @@ def get_autotune_configs(): raise ValueError("Unknown Device Type") else: return [ + # Optimized for gfx1100 (RDNA3) with LLC-aware head grouping triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": True}, num_stages=1, 
num_warps=4, ), diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index 06ab7d24d56..c223ee93b6c 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -9,11 +9,17 @@ from .fwd_ref import attention_forward_pytorch_ref_impl from .bwd_ref import attention_backward_pytorch_ref_impl from .utils import DEBUG, USE_REF, MetaData, get_shapes_from_layout, is_fp8 +from .l2_cache_aware import is_head_grouping_beneficial, print_head_grouping_info from einops import rearrange, repeat from flash_attn.layers.rotary import apply_rotary_emb from typing import Literal, Optional, Union -def fwd(q: torch.Tensor, +# Environment variable to enable verbose head grouping output +L2_HEAD_GROUPING_DEBUG = os.environ.get('FLASH_ATTN_HEAD_GROUPING_DEBUG', '0') == '1' + + + +def _fwd_single_group(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: Optional[torch.Tensor], @@ -31,6 +37,7 @@ def fwd(q: torch.Tensor, descale_v: Optional[torch.Tensor] = None, descale_o: Optional[torch.Tensor] = None ): + """Original fwd implementation for a single head group.""" if DEBUG: print() @@ -145,6 +152,112 @@ def fwd(q: torch.Tensor, return out, softmax_lse, sd_mask, rng_state + +def fwd(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None + ): + """ + Flash attention forward with L2 cache-aware head grouping. 
+ + For consumer AMD GPUs (e.g., gfx1100 with 96MB L2), when K,V tensors + exceed L2 cache capacity, processing heads in groups that fit in L2 + can provide up to 2x speedup. + + Layout: bshd (batch, seqlen, heads, head_dim) + """ + # Get shapes for head grouping decision + # Layout is bshd: [batch, seqlen, heads, head_dim] + batch, seqlen_q, nheads_q, head_dim = q.shape + seqlen_k = k.shape[1] + nheads_k = k.shape[2] + + # Check if head grouping is beneficial + should_group, group_size = is_head_grouping_beneficial( + nheads_q, seqlen_k, head_dim, q.dtype, q.device.index or 0 + ) + + if L2_HEAD_GROUPING_DEBUG: + print_head_grouping_info(nheads_q, seqlen_k, head_dim, q.dtype, q.device.index or 0) + + if not should_group or group_size >= nheads_q or nheads_k != nheads_q or return_softmax or dropout_p > 0.0: + # No grouping needed - use original implementation + return _fwd_single_group( + q, k, v, out, alibi_slopes, dropout_p, softmax_scale, causal, + window_size_left, window_size_right, softcap, return_softmax, + gen_, descale_q, descale_k, descale_v, descale_o + ) + + # Process heads in groups for L2 cache efficiency + if L2_HEAD_GROUPING_DEBUG: + print(f"[L2 Head Grouping] Processing {nheads_q} heads in groups of {group_size}") + + # Prepare output tensor + if is_fp8(q): + assert out is not None, "fp8 output tensor should be passed in." 
+ else: + out = torch.zeros_like(q) if out is None else out.zero_() + + # Collect outputs for each group + softmax_lse_list = [] + rng_state = None + + n_groups = (nheads_q + group_size - 1) // group_size + + for g in range(n_groups): + start_h = g * group_size + end_h = min((g + 1) * group_size, nheads_q) + + # Slice heads: bshd layout -> select heads on dim 2 + q_group = q[:, :, start_h:end_h, :].contiguous() + k_group = k[:, :, start_h:end_h, :].contiguous() + v_group = v[:, :, start_h:end_h, :].contiguous() + out_group = out[:, :, start_h:end_h, :].contiguous() + + # Handle alibi slopes if present + alibi_group = None + if alibi_slopes is not None: + alibi_group = alibi_slopes[start_h:end_h] if alibi_slopes.dim() == 1 else alibi_slopes[:, start_h:end_h] + + # Handle descale tensors for fp8 + descale_q_g = descale_q[:, start_h:end_h] if descale_q is not None else None + descale_k_g = descale_k[:, start_h:end_h] if descale_k is not None else None + descale_v_g = descale_v[:, start_h:end_h] if descale_v is not None else None + descale_o_g = descale_o[:, start_h:end_h] if descale_o is not None else None + + # Call the original implementation for this group + out_g, softmax_lse_g, sd_mask_g, rng_state = _fwd_single_group( + q_group, k_group, v_group, out_group, alibi_group, + dropout_p, softmax_scale, causal, + window_size_left, window_size_right, softcap, return_softmax, + gen_, descale_q_g, descale_k_g, descale_v_g, descale_o_g + ) + + # Copy output back to the main tensor + out[:, :, start_h:end_h, :] = out_g + softmax_lse_list.append(softmax_lse_g) + + # Concatenate softmax_lse across heads + softmax_lse = torch.cat(softmax_lse_list, dim=1) # Assuming lse is [batch, heads, ...] 
+ + return out, softmax_lse, None, rng_state + + + BWD_MODE = os.environ.get('BWD_MODE', 'split').lower() def bwd( dout: torch.Tensor, @@ -349,7 +462,7 @@ def bwd( print("descale_dq:", descale_dq, descale_dq.shape if descale_dq is not None else None) return dq, dk, dv, delta -def varlen_fwd( +def _varlen_fwd_single_group( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -376,6 +489,7 @@ def varlen_fwd( descale_v: Optional[torch.Tensor] = None, descale_o: Optional[torch.Tensor] = None ): + """Original varlen_fwd implementation for a single head group.""" if DEBUG: print() @@ -490,6 +604,121 @@ def varlen_fwd( return out, softmax_lse, sd_mask, rng_state + +def varlen_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], + leftpad_k: Optional[torch.Tensor], + block_table_: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + softmax_scale: float, + zero_tensors: bool , + causal: bool , + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None + ): + """ + Variable-length flash attention forward with L2 cache-aware head grouping. + + For consumer AMD GPUs (e.g., gfx1100 with 96MB L2), when K,V tensors + exceed L2 cache capacity, processing heads in groups that fit in L2 + can provide up to 2x speedup. 
+ + Layout: thd (total_seqlen, heads, head_dim) + """ + # Get shapes for head grouping decision + # Layout is thd: [total_seqlen, heads, head_dim] + total_seqlen, nheads_q, head_dim = q.shape + nheads_k = k.shape[1] + + # Check if head grouping is beneficial + should_group, group_size = is_head_grouping_beneficial( + nheads_q, max_seqlen_k, head_dim, q.dtype, q.device.index or 0 + ) + + if L2_HEAD_GROUPING_DEBUG: + print_head_grouping_info(nheads_q, max_seqlen_k, head_dim, q.dtype, q.device.index or 0) + + if not should_group or group_size >= nheads_q or nheads_k != nheads_q or block_table_ is not None or return_softmax or dropout_p > 0.0: + # No grouping needed - use original implementation + return _varlen_fwd_single_group( + q, k, v, out, cu_seqlens_q, cu_seqlens_k, seqused_k, leftpad_k, + block_table_, alibi_slopes, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, zero_tensors, causal, + window_size_left, window_size_right, softcap, return_softmax, + gen_, descale_q, descale_k, descale_v, descale_o + ) + + # Process heads in groups for L2 cache efficiency + if L2_HEAD_GROUPING_DEBUG: + print(f"[L2 Head Grouping varlen] Processing {nheads_q} heads in groups of {group_size}") + + # Prepare output tensor + if is_fp8(q): + assert out is not None, "fp8 output tensor should be passed in." 
+ else: + out = torch.zeros_like(q) if out is None else out.zero_() + + # Collect outputs for each group + softmax_lse_list = [] + rng_state = None + + n_groups = (nheads_q + group_size - 1) // group_size + + for g in range(n_groups): + start_h = g * group_size + end_h = min((g + 1) * group_size, nheads_q) + + # Slice heads: thd layout -> select heads on dim 1 + q_group = q[:, start_h:end_h, :].contiguous() + k_group = k[:, start_h:end_h, :].contiguous() + v_group = v[:, start_h:end_h, :].contiguous() + out_group = out[:, start_h:end_h, :].contiguous() + + # Handle alibi slopes if present + alibi_group = None + if alibi_slopes is not None: + alibi_group = alibi_slopes[start_h:end_h] if alibi_slopes.dim() == 1 else alibi_slopes[:, start_h:end_h] + + # Handle descale tensors for fp8 + descale_q_g = descale_q[:, start_h:end_h] if descale_q is not None else None + descale_k_g = descale_k[:, start_h:end_h] if descale_k is not None else None + descale_v_g = descale_v[:, start_h:end_h] if descale_v is not None else None + descale_o_g = descale_o[:, start_h:end_h] if descale_o is not None else None + + # Call the original implementation for this group + out_g, softmax_lse_g, sd_mask_g, rng_state = _varlen_fwd_single_group( + q_group, k_group, v_group, out_group, cu_seqlens_q, cu_seqlens_k, + seqused_k, leftpad_k, block_table_, alibi_group, + max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, + zero_tensors, causal, window_size_left, window_size_right, + softcap, return_softmax, gen_, descale_q_g, descale_k_g, descale_v_g, descale_o_g + ) + + # Copy output back to the main tensor + out[:, start_h:end_h, :] = out_g + softmax_lse_list.append(softmax_lse_g) + + # Concatenate softmax_lse across heads + softmax_lse = torch.cat(softmax_lse_list, dim=0) # varlen lse is [heads, total_seqlen] + + return out, softmax_lse, None, rng_state + def varlen_bwd( dout: torch.Tensor, q: torch.Tensor, diff --git a/flash_attn/flash_attn_triton_amd/l2_cache_aware.py 
b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py new file mode 100644 index 00000000000..981f8ce5702 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/l2_cache_aware.py @@ -0,0 +1,271 @@ +""" +Infinity Cache (LLC) Aware Head Grouping for Flash Attention + +This module provides functionality to optimize flash attention by processing +heads in groups that fit in the Last Level Cache (LLC / Infinity Cache). + +AMD RDNA3 cache hierarchy: +- L2 Cache: 6 MB (per-die, fast) +- Infinity Cache (L3/LLC): 96 MB (acts as memory-side cache) + +For large sequence lengths, we want K,V to fit in the 96 MB Infinity Cache. +By processing heads in groups that fit, we achieve up to 2x speedup. + +Example: gfx1100 with 96MB Infinity Cache, 40 heads, seqlen=17160, head_dim=128 +- K,V for all 40 heads = 352 MB (exceeds 96 MB LLC) +- K,V for 10 heads = 88 MB (fits in 96 MB LLC) +- Processing 10 heads at a time gives 1.95x speedup +""" + +import os +import functools +from typing import Optional, Tuple, Dict +import torch + +# Infinity Cache (LLC) sizes for AMD GPUs in bytes +# Note: This is the L3/Infinity Cache, NOT the L2 cache +# RDNA3: L2=6MB, Infinity Cache (LLC)=96MB +AMD_LLC_CACHE_SIZES: Dict[str, int] = { + # RDNA3 consumer + "gfx1100": 96 * 1024 * 1024, # RX 7900 XTX/XT - 96 MB Infinity Cache + "gfx1101": 64 * 1024 * 1024, # RX 7800 XT - 64 MB Infinity Cache + "gfx1102": 32 * 1024 * 1024, # RX 7600 - 32 MB Infinity Cache +} + +# Legacy alias for backwards compatibility +AMD_L2_CACHE_SIZES = AMD_LLC_CACHE_SIZES + +# Environment variable to override LLC cache size (in MB) +LLC_CACHE_OVERRIDE_ENV = "FLASH_ATTN_LLC_CACHE_MB" +L2_CACHE_OVERRIDE_ENV = "FLASH_ATTN_L2_CACHE_MB" # Legacy alias + +# Environment variable to disable head grouping +DISABLE_HEAD_GROUPING_ENV = "FLASH_ATTN_DISABLE_HEAD_GROUPING" + +# Cached LLC size per device +_llc_cache_size_cache: Dict[int, int] = {} + + +@functools.lru_cache(maxsize=None) +def get_gcn_arch_name(device_index: int = 0) -> str: + """Get 
the GCN architecture name for an AMD GPU.""" + try: + props = torch.cuda.get_device_properties(device_index) + if hasattr(props, 'gcnArchName'): + return props.gcnArchName + # Fallback: try to get from name + name = props.name.lower() + if 'gfx' in name: + # Extract gfxXXXX from name + import re + match = re.search(r'gfx\d+', name) + if match: + return match.group() + except Exception: + pass + return "unknown" + + +def get_num_cus(device_index: int = 0) -> int: + """ + Get the number of Compute Units for an AMD GPU. + + Note: PyTorch's multi_processor_count may be incorrect for some AMD GPUs. + We use known values for common architectures. + """ + arch = get_gcn_arch_name(device_index) + + # Known CU counts for common GPUs + known_cus = { + "gfx1100": 96, # RX 7900 XTX + "gfx1101": 60, # RX 7800 XT + "gfx1102": 32, # RX 7600 + } + + if arch in known_cus: + return known_cus[arch] + + # Fallback to PyTorch (may be incorrect) + try: + props = torch.cuda.get_device_properties(device_index) + return props.multi_processor_count + except Exception: + return 96 # Default + + +def get_llc_cache_size(device_index: int = 0) -> int: + """ + Get Infinity Cache (LLC) size for the specified GPU device. + + For RDNA3, this is the 96 MB Infinity Cache, not the 6 MB L2. 
+ + Returns: + LLC cache size in bytes + """ + global _llc_cache_size_cache + + if device_index in _llc_cache_size_cache: + return _llc_cache_size_cache[device_index] + + # Check for environment override (new name first, then legacy) + for env_var in [LLC_CACHE_OVERRIDE_ENV, L2_CACHE_OVERRIDE_ENV]: + if env_var in os.environ: + try: + size_mb = int(os.environ[env_var]) + size_bytes = size_mb * 1024 * 1024 + _llc_cache_size_cache[device_index] = size_bytes + return size_bytes + except ValueError: + pass + + # Get architecture and look up cache size + arch = get_gcn_arch_name(device_index) + + # Check exact match first + if arch in AMD_LLC_CACHE_SIZES: + size = AMD_LLC_CACHE_SIZES[arch] + _llc_cache_size_cache[device_index] = size + return size + + # Check prefix match (e.g., gfx1100 matches gfx1100) + for known_arch, size in AMD_LLC_CACHE_SIZES.items(): + if arch.startswith(known_arch): + _llc_cache_size_cache[device_index] = size + return size + + # Default: assume 96 MB (conservative for RDNA3) + default_size = 96 * 1024 * 1024 + _llc_cache_size_cache[device_index] = default_size + return default_size + + +# Legacy alias +get_l2_cache_size = get_llc_cache_size + + +def calculate_optimal_head_group_size( + seqlen_k: int, + head_dim: int, + dtype: torch.dtype, + device_index: int = 0, + llc_utilization: float = 1.5 # Use 150% of LLC - optimal for long sequences +) -> int: + """ + Calculate the optimal number of heads to process together to fit K,V in LLC. 
+ """ + llc_size = get_llc_cache_size(device_index) + + # Get element size in bytes + if dtype in (torch.float16, torch.bfloat16): + elem_size = 2 + elif dtype == torch.float32: + elem_size = 4 + elif 'float8' in str(dtype).lower(): + elem_size = 1 + else: + elem_size = 2 # Default to fp16 + + # Memory for K and V per head + kv_per_head = seqlen_k * head_dim * elem_size * 2 # *2 for K and V + + # Target LLC usage + target_llc = int(llc_size * llc_utilization) + + # Calculate number of heads that fit + if kv_per_head == 0: + return 1 + + head_group_size = max(1, target_llc // kv_per_head) + + return head_group_size + + +def is_head_grouping_beneficial( + nheads: int, + seqlen_k: int, + head_dim: int, + dtype: torch.dtype, + device_index: int = 0, + threshold_ratio: float = 1.5 +) -> Tuple[bool, int]: + """ + Determine if head grouping would be beneficial and return optimal group size. + """ + # Check if disabled via environment + if os.environ.get(DISABLE_HEAD_GROUPING_ENV, "0") == "1": + return False, nheads + + llc_size = get_llc_cache_size(device_index) + + # Get element size + if dtype in (torch.float16, torch.bfloat16): + elem_size = 2 + elif dtype == torch.float32: + elem_size = 4 + elif 'float8' in str(dtype).lower(): + elem_size = 1 + else: + elem_size = 2 + + # Total K,V memory for all heads + total_kv = nheads * seqlen_k * head_dim * elem_size * 2 + + # Only group if K,V significantly exceeds LLC + if total_kv < llc_size * threshold_ratio: + return False, nheads + + # Calculate optimal group size + group_size = calculate_optimal_head_group_size( + seqlen_k, head_dim, dtype, device_index + ) + + # Only group if we'd have at least 2 groups + if group_size >= nheads: + return False, nheads + + # Minimum group size to avoid excessive kernel launches + min_group_size = max(1, nheads // 16) # At most 16 groups + group_size = max(group_size, min_group_size) + + return True, min(group_size, nheads) + + +def print_head_grouping_info( + nheads: int, + seqlen_k: int, 
+ head_dim: int, + dtype: torch.dtype, + device_index: int = 0 +): + """Print diagnostic information about head grouping.""" + llc_size = get_llc_cache_size(device_index) + arch = get_gcn_arch_name(device_index) + num_cus = get_num_cus(device_index) + + if dtype in (torch.float16, torch.bfloat16): + elem_size = 2 + elif dtype == torch.float32: + elem_size = 4 + elif 'float8' in str(dtype).lower(): + elem_size = 1 + else: + elem_size = 2 + + total_kv = nheads * seqlen_k * head_dim * elem_size * 2 + should_group, group_size = is_head_grouping_beneficial( + nheads, seqlen_k, head_dim, dtype, device_index + ) + + print(f"\n=== Infinity Cache (LLC) Aware Head Grouping ===") + print(f"GPU: {arch} ({num_cus} CUs)") + print(f"Infinity Cache (LLC): {llc_size / (1024*1024):.1f} MB") + print(f"Heads: {nheads}, SeqLen: {seqlen_k}, HeadDim: {head_dim}") + print(f"Total K,V Memory: {total_kv / (1024*1024):.1f} MB") + print(f"LLC Ratio: {total_kv / llc_size:.2f}x") + print(f"Should Group: {should_group}") + if should_group: + kv_per_group = group_size * seqlen_k * head_dim * elem_size * 2 + num_groups = (nheads + group_size - 1) // group_size + print(f"Group Size: {group_size} heads ({num_groups} groups)") + print(f"K,V per Group: {kv_per_group / (1024*1024):.1f} MB") + print("=" * 48 + "\n") diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 5d3bf02e1f8..167b99e0d81 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -48,7 +48,7 @@ class MetaData(): philox_seed: Optional[int] = None philox_offset : Optional[int]= None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing. # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW. 
- use_exp2: bool = False + use_exp2: bool = True rotary_sin: Optional[torch.Tensor] = None rotary_cos: Optional[torch.Tensor] = None rotary_interleaved: bool = False diff --git a/setup.py b/setup.py index f0b476255ba..9b1bd10088a 100644 --- a/setup.py +++ b/setup.py @@ -197,7 +197,7 @@ def rename_cpp_to_cu(cpp_files): def validate_and_update_archs(archs): # List of allowed architectures - allowed_archs = ["native", "gfx90a", "gfx950", "gfx942"] + allowed_archs = ["native", "gfx90a", "gfx950", "gfx942", "gfx1100"] # Validate if each element in archs is in allowed_archs assert all(