27 changes: 26 additions & 1 deletion flash_attn/flash_attn_triton_amd/fwd_decode.py
@@ -2,7 +2,7 @@
import triton
import triton.language as tl
from typing import Literal, Optional, Union
from .utils import AUTOTUNE, DEBUG, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna
from .utils import AUTOTUNE, DEBUG, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna, is_rdna

def get_cdna_autotune_configs():
return [
@@ -23,13 +23,38 @@ def get_cdna_autotune_configs():
num_warps=4),
], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']

def get_rdna_autotune_configs():
return [
# Most aggressive - 128x128 (best for large sequences)
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
# Large blocks
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8),
# Medium blocks
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2),
# Fall-back config.
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK']

def get_autotune_configs():
if AUTOTUNE:
if is_cdna():
autotune_configs, autotune_keys = get_cdna_autotune_configs()
fwd_auto_tune_configs, fwd_autotune_keys = autotune_configs, autotune_keys
reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys
return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys)
elif is_rdna():
autotune_configs, autotune_keys = get_rdna_autotune_configs()
fwd_auto_tune_configs, fwd_autotune_keys = autotune_configs, autotune_keys
reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys
return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys)
else:
raise ValueError("Unknown Device Type")
else:
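For reference, the (configs, keys) pairs assembled above are the two arguments that triton.autotune expects: configs are candidate launch parameters, and keys name the kernel arguments whose values trigger re-tuning. A minimal, self-contained sketch of that wiring with a toy kernel follows; the kernel, its arguments, and the key are illustrative, not the actual decode kernels.

import torch
import triton
import triton.language as tl

# Candidate launch parameters, in the same shape as get_rdna_autotune_configs() returns.
configs = [
    triton.Config({'BLOCK': 64}, num_stages=1, num_warps=4),
    triton.Config({'BLOCK': 128}, num_stages=1, num_warps=8),
]
keys = ['n_elements']  # re-run autotuning whenever this argument changes

@triton.autotune(configs=configs, key=keys)
@triton.jit
def scale_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)

x = torch.randn(4096, device='cuda')
y = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK']),)
scale_kernel[grid](x, y, x.numel())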
31 changes: 15 additions & 16 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -186,24 +186,22 @@ def get_cdna_autotune_configs():

def get_rdna_autotune_configs():
return [
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
# Fall-back config.
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
# Best config from autotune on gfx1100: 32x16, warps=2, PRE_LOAD_V=True
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=1),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=1),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2),
# === Configs for head_dim=128 ===
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=2),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4),
# === Fallback configs ===
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK']



def get_autotune_configs():
if AUTOTUNE:
if is_rdna():
@@ -214,8 +212,9 @@ def get_autotune_configs():
raise ValueError("Unknown Device Type")
else:
return [
# Optimized for gfx1100 (RDNA3) with LLC-aware head grouping
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False},
{"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": True},
num_stages=1,
num_warps=4,
),
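To check what the new default and RDNA autotune configs do for a particular workload, a rough end-to-end timing harness like the one below can be used. It assumes flash_attn is installed from this branch with the Triton AMD backend active; the shapes and dtype are only an example.

import torch
from flash_attn import flash_attn_func

def bench(fn, iters=50, warmup=10):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

# bshd layout: (batch, seqlen, nheads, head_dim), fp16 prefill-sized inputs
q = torch.randn(1, 4096, 32, 128, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

print(f"prefill fwd: {bench(lambda: flash_attn_func(q, k, v, causal=True)):.3f} ms/iter")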
233 changes: 231 additions & 2 deletions flash_attn/flash_attn_triton_amd/interface_fa.py
@@ -9,11 +9,17 @@
from .fwd_ref import attention_forward_pytorch_ref_impl
from .bwd_ref import attention_backward_pytorch_ref_impl
from .utils import DEBUG, USE_REF, MetaData, get_shapes_from_layout, is_fp8
from .l2_cache_aware import is_head_grouping_beneficial, print_head_grouping_info
from einops import rearrange, repeat
from flash_attn.layers.rotary import apply_rotary_emb
from typing import Literal, Optional, Union

def fwd(q: torch.Tensor,
# Environment variable to enable verbose head grouping output
L2_HEAD_GROUPING_DEBUG = os.environ.get('FLASH_ATTN_HEAD_GROUPING_DEBUG', '0') == '1'



def _fwd_single_group(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: Optional[torch.Tensor],
@@ -31,6 +37,7 @@ def fwd(q: torch.Tensor,
descale_v: Optional[torch.Tensor] = None,
descale_o: Optional[torch.Tensor] = None
):
"""Original fwd implementation for a single head group."""

if DEBUG:
print()
@@ -145,6 +152,112 @@ def fwd(q: torch.Tensor,

return out, softmax_lse, sd_mask, rng_state


def fwd(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: Optional[torch.Tensor],
alibi_slopes: Optional[torch.Tensor],
dropout_p: float,
softmax_scale: float,
causal: bool,
window_size_left: int,
window_size_right: int,
softcap: float,
return_softmax: bool,
gen_: Optional[torch.Tensor] = None,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
descale_v: Optional[torch.Tensor] = None,
descale_o: Optional[torch.Tensor] = None
):
"""
Flash attention forward with L2 cache-aware head grouping.

For consumer AMD GPUs (e.g., gfx1100 with 96MB L2), when K,V tensors
exceed L2 cache capacity, processing heads in groups that fit in L2
can provide up to 2x speedup.

Layout: bshd (batch, seqlen, heads, head_dim)
"""
# Get shapes for head grouping decision
# Layout is bshd: [batch, seqlen, heads, head_dim]
batch, seqlen_q, nheads_q, head_dim = q.shape
seqlen_k = k.shape[1]
nheads_k = k.shape[2]

# Check if head grouping is beneficial
should_group, group_size = is_head_grouping_beneficial(
nheads_q, seqlen_k, head_dim, q.dtype, q.device.index or 0
)

if L2_HEAD_GROUPING_DEBUG:
print_head_grouping_info(nheads_q, seqlen_k, head_dim, q.dtype, q.device.index or 0)

if not should_group or group_size >= nheads_q:
# No grouping needed - use original implementation
return _fwd_single_group(
q, k, v, out, alibi_slopes, dropout_p, softmax_scale, causal,
window_size_left, window_size_right, softcap, return_softmax,
gen_, descale_q, descale_k, descale_v, descale_o
)

# Process heads in groups for L2 cache efficiency
if L2_HEAD_GROUPING_DEBUG:
print(f"[L2 Head Grouping] Processing {nheads_q} heads in groups of {group_size}")

# Prepare output tensor
if is_fp8(q):
assert out is not None, "fp8 output tensor should be passed in."
else:
out = torch.zeros_like(q) if out is None else out.zero_()

# Collect outputs for each group
softmax_lse_list = []
rng_state = None

n_groups = (nheads_q + group_size - 1) // group_size

for g in range(n_groups):
start_h = g * group_size
end_h = min((g + 1) * group_size, nheads_q)

# Slice heads: bshd layout -> select heads on dim 2
q_group = q[:, :, start_h:end_h, :].contiguous()
k_group = k[:, :, start_h:end_h, :].contiguous()
v_group = v[:, :, start_h:end_h, :].contiguous()
out_group = out[:, :, start_h:end_h, :].contiguous()

# Handle alibi slopes if present
alibi_group = None
if alibi_slopes is not None:
alibi_group = alibi_slopes[start_h:end_h] if alibi_slopes.dim() == 1 else alibi_slopes[:, start_h:end_h]

# Handle descale tensors for fp8
descale_q_g = descale_q[:, start_h:end_h] if descale_q is not None else None
descale_k_g = descale_k[:, start_h:end_h] if descale_k is not None else None
descale_v_g = descale_v[:, start_h:end_h] if descale_v is not None else None
descale_o_g = descale_o[:, start_h:end_h] if descale_o is not None else None

# Call the original implementation for this group
out_g, softmax_lse_g, sd_mask_g, rng_state = _fwd_single_group(
q_group, k_group, v_group, out_group, alibi_group,
dropout_p, softmax_scale, causal,
window_size_left, window_size_right, softcap, return_softmax,
gen_, descale_q_g, descale_k_g, descale_v_g, descale_o_g
)

# Copy output back to the main tensor
out[:, :, start_h:end_h, :] = out_g
softmax_lse_list.append(softmax_lse_g)

# Concatenate softmax_lse across heads
softmax_lse = torch.cat(softmax_lse_list, dim=1) # Assuming lse is [batch, heads, ...]

return out, softmax_lse, None, rng_state
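is_head_grouping_beneficial and print_head_grouping_info are imported from the new l2_cache_aware module, which is not part of this excerpt. Going only by the docstring above, a plausible sketch of such a heuristic is shown below; the cache size, the sizing rule, and the return convention are assumptions for illustration, not the PR's actual implementation.

import torch

CACHE_BYTES = 96 * 1024 * 1024  # capacity the docstring assumes for gfx1100

def is_head_grouping_beneficial_sketch(nheads, seqlen_k, head_dim, dtype, device_index=0):
    elem = torch.tensor([], dtype=dtype).element_size()
    kv_bytes = 2 * seqlen_k * nheads * head_dim * elem   # K and V working set
    if kv_bytes <= CACHE_BYTES:
        return False, nheads                              # already fits, no grouping
    per_head_bytes = 2 * seqlen_k * head_dim * elem
    group_size = max(1, min(nheads, CACHE_BYTES // per_head_bytes))
    return group_size < nheads, int(group_size)

# 32 heads, 64k context, head_dim 128, fp16 -> K/V is 1 GiB, so group heads
print(is_head_grouping_beneficial_sketch(32, 65536, 128, torch.float16))  # (True, 3)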



BWD_MODE = os.environ.get('BWD_MODE', 'split').lower()
def bwd(
dout: torch.Tensor,
@@ -349,7 +462,7 @@ def bwd(
print("descale_dq:", descale_dq, descale_dq.shape if descale_dq is not None else None)
return dq, dk, dv, delta

def varlen_fwd(
def _varlen_fwd_single_group(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
@@ -376,6 +489,7 @@ def varlen_fwd(
descale_v: Optional[torch.Tensor] = None,
descale_o: Optional[torch.Tensor] = None
):
"""Original varlen_fwd implementation for a single head group."""

if DEBUG:
print()
@@ -490,6 +604,121 @@ def varlen_fwd(

return out, softmax_lse, sd_mask, rng_state


def varlen_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: Optional[torch.Tensor],
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
seqused_k: Optional[torch.Tensor],
leftpad_k: Optional[torch.Tensor],
block_table_: Optional[torch.Tensor],
alibi_slopes: Optional[torch.Tensor],
max_seqlen_q: int,
max_seqlen_k: int,
dropout_p: float,
softmax_scale: float,
zero_tensors: bool,
causal: bool,
window_size_left: int,
window_size_right: int,
softcap: float,
return_softmax: bool,
gen_: Optional[torch.Tensor] = None,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
descale_v: Optional[torch.Tensor] = None,
descale_o: Optional[torch.Tensor] = None
):
"""
Variable-length flash attention forward with L2 cache-aware head grouping.

For consumer AMD GPUs (e.g., gfx1100 with 96MB L2), when K,V tensors
exceed L2 cache capacity, processing heads in groups that fit in L2
can provide up to 2x speedup.

Layout: thd (total_seqlen, heads, head_dim)
"""
# Get shapes for head grouping decision
# Layout is thd: [total_seqlen, heads, head_dim]
total_seqlen, nheads_q, head_dim = q.shape
nheads_k = k.shape[1]

# Check if head grouping is beneficial
should_group, group_size = is_head_grouping_beneficial(
nheads_q, max_seqlen_k, head_dim, q.dtype, q.device.index or 0
)

if L2_HEAD_GROUPING_DEBUG:
print_head_grouping_info(nheads_q, max_seqlen_k, head_dim, q.dtype, q.device.index or 0)

if not should_group or group_size >= nheads_q:
# No grouping needed - use original implementation
return _varlen_fwd_single_group(
q, k, v, out, cu_seqlens_q, cu_seqlens_k, seqused_k, leftpad_k,
block_table_, alibi_slopes, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, zero_tensors, causal,
window_size_left, window_size_right, softcap, return_softmax,
gen_, descale_q, descale_k, descale_v, descale_o
)

# Process heads in groups for L2 cache efficiency
if L2_HEAD_GROUPING_DEBUG:
print(f"[L2 Head Grouping varlen] Processing {nheads_q} heads in groups of {group_size}")

# Prepare output tensor
if is_fp8(q):
assert out is not None, "fp8 output tensor should be passed in."
else:
out = torch.zeros_like(q) if out is None else out.zero_()

# Collect outputs for each group
softmax_lse_list = []
rng_state = None

n_groups = (nheads_q + group_size - 1) // group_size

for g in range(n_groups):
start_h = g * group_size
end_h = min((g + 1) * group_size, nheads_q)

# Slice heads: thd layout -> select heads on dim 1
q_group = q[:, start_h:end_h, :].contiguous()
k_group = k[:, start_h:end_h, :].contiguous()
v_group = v[:, start_h:end_h, :].contiguous()
out_group = out[:, start_h:end_h, :].contiguous()

# Handle alibi slopes if present
alibi_group = None
if alibi_slopes is not None:
alibi_group = alibi_slopes[start_h:end_h] if alibi_slopes.dim() == 1 else alibi_slopes[:, start_h:end_h]

# Handle descale tensors for fp8
descale_q_g = descale_q[:, start_h:end_h] if descale_q is not None else None
descale_k_g = descale_k[:, start_h:end_h] if descale_k is not None else None
descale_v_g = descale_v[:, start_h:end_h] if descale_v is not None else None
descale_o_g = descale_o[:, start_h:end_h] if descale_o is not None else None

# Call the original implementation for this group
out_g, softmax_lse_g, sd_mask_g, rng_state = _varlen_fwd_single_group(
q_group, k_group, v_group, out_group, cu_seqlens_q, cu_seqlens_k,
seqused_k, leftpad_k, block_table_, alibi_group,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale,
zero_tensors, causal, window_size_left, window_size_right,
softcap, return_softmax, gen_, descale_q_g, descale_k_g, descale_v_g, descale_o_g
)

# Copy output back to the main tensor
out[:, start_h:end_h, :] = out_g
softmax_lse_list.append(softmax_lse_g)

# Concatenate softmax_lse across heads
softmax_lse = torch.cat(softmax_lse_list, dim=0) # varlen lse is [heads, total_seqlen]

return out, softmax_lse, None, rng_state
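The grouping decision can be inspected at runtime via the FLASH_ATTN_HEAD_GROUPING_DEBUG environment variable defined near the top of this file; it has to be set before flash_attn is imported, since the flag is read at import time. A small usage sketch for the varlen path, assuming flash_attn is installed from this branch on an AMD GPU with the Triton backend active:

import os
os.environ['FLASH_ATTN_HEAD_GROUPING_DEBUG'] = '1'  # set before importing flash_attn

import torch
from flash_attn import flash_attn_varlen_func

# Two sequences of length 2048 packed into one thd batch (total_seqlen = 4096).
total, nheads, head_dim = 4096, 32, 128
q = torch.randn(total, nheads, head_dim, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
cu_seqlens = torch.tensor([0, 2048, 4096], device='cuda', dtype=torch.int32)

out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, 2048, 2048, causal=True)
print(out.shape)  # torch.Size([4096, 32, 128])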

def varlen_bwd(
dout: torch.Tensor,
q: torch.Tensor,