239 changes: 238 additions & 1 deletion aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/fwd_prefill.py
@@ -13,6 +13,7 @@
is_fp8,
remap_xcd,
)
from .llc_cache_aware import is_head_grouping_beneficial

FWD_PREFILL_AUTOTUNE_KEYS = [
"IS_CAUSAL",
@@ -1398,7 +1399,7 @@ def attn_fwd(
tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask)


def attention_forward_prefill_triton_impl(
def _attention_forward_prefill_triton_impl_core(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
Expand Down Expand Up @@ -1873,3 +1874,239 @@ def grid(META):
FORCE_MASKING=force_masking,
NUM_XCD=num_xcd,
)


def attention_forward_prefill_triton_impl(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
softmax_lse: torch.Tensor,
sd_mask: Optional[torch.Tensor],
sm_scale: float,
alibi_slopes: Optional[torch.Tensor],
causal: bool,
window_size_left: int,
window_size_right: int,
bias: Optional[torch.Tensor],
layout: Literal["bshd", "bhsd", "thd"],
# varlen
cu_seqlens_q: Optional[torch.Tensor],
cu_seqlens_k: Optional[torch.Tensor],
max_seqlens_q: int,
max_seqlens_k: int,
# dropout
dropout_p: float,
philox_seed: Optional[int],
philox_offset: Optional[int],
# misc
return_scores: bool,
use_exp2: bool,
# fp8
q_descale: Optional[torch.Tensor],
k_descale: Optional[torch.Tensor],
v_descale: Optional[torch.Tensor],
# seqused for FA v3
seqused_q: Optional[torch.Tensor] = None,
seqused_k: Optional[torch.Tensor] = None,
# rotary (optional)
rotary_cos: Optional[torch.Tensor] = None,
rotary_sin: Optional[torch.Tensor] = None,
rotary_interleaved: bool = False,
seqlens_rotary: Optional[torch.Tensor] = None,
):
"""
Wrapper for attention forward with LLC-aware head grouping optimization.

For long sequences on GPUs with large LLC (e.g., RDNA3 with 96 MB Infinity Cache),
processing heads in groups that fit K,V in cache can significantly improve performance.
"""
IS_VARLEN = layout == "thd"

# Get head dimensions (the layout determines which axis holds the heads)
if IS_VARLEN:
total_q, nheads_q, head_dim = q.shape
total_k = k.shape[0]
nheads_k = k.shape[1]
elif layout == "bhsd":
batch, nheads_q, seqlen_q, head_dim = q.shape
nheads_k = k.shape[1]
else: # bshd
batch, seqlen_q, nheads_q, head_dim = q.shape
nheads_k = k.shape[2]

# Check if head grouping is beneficial
should_group, group_size = is_head_grouping_beneficial(
nheads_k, max_seqlens_k, head_dim, q.dtype, q.device.index or 0
)
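# Illustrative sizing only (the actual heuristic lives in llc_cache_aware):
# with a 96 MB LLC, fp16, head_dim=128 and seqlen_k=32768, one K/V head
# pair takes 2 * 32768 * 128 * 2 B = 16 MB, so about 6 K/V heads fit in
# cache at once.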

# Disable head grouping if return_scores is requested (it needs the full
# attention matrix), if sd_mask is provided, or for the bhsd layout (the
# grouping path below slices the head axis of bshd/thd tensors only)
if return_scores or sd_mask is not None or layout == "bhsd":
should_group = False

if not should_group or group_size >= nheads_q:
# No grouping needed - call core implementation directly
return _attention_forward_prefill_triton_impl_core(
q,
k,
v,
o,
softmax_lse,
sd_mask,
sm_scale,
alibi_slopes,
causal,
window_size_left,
window_size_right,
bias,
layout,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
dropout_p,
philox_seed,
philox_offset,
return_scores,
use_exp2,
q_descale,
k_descale,
v_descale,
seqused_q,
seqused_k,
rotary_cos,
rotary_sin,
rotary_interleaved,
seqlens_rotary,
)

# Head grouping path
if DEBUG:
print(
f"[LLC Head Grouping fwd_prefill] Processing {nheads_q} heads in groups of {group_size}"
)

gqa_ratio = nheads_q // nheads_k
n_groups = (nheads_q + group_size - 1) // group_size

# Calculate K,V heads per group (for GQA). If group_size is not a multiple
# of gqa_ratio, a group's q-head range can straddle one extra K/V head, so
# size the buffers for that worst case.
group_size_k = (group_size + gqa_ratio - 1) // gqa_ratio
if group_size % gqa_ratio != 0:
group_size_k = min(group_size_k + 1, nheads_k)
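# e.g., nheads_q=32, nheads_k=8 (gqa_ratio=4) with group_size=8 yields
# n_groups=4 and group_size_k=2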

# Pre-allocate K,V buffers to avoid repeated allocations in loop
# This reuses memory across iterations instead of calling .contiguous() each time
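# Buffers are sized for a full group; the final (possibly smaller) group
# uses only the leading actual_heads_k slots (see the slicing in the loop)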
if IS_VARLEN:
# thd layout: (total_k_tokens, nheads_k_group, head_dim); sized with
# total_k (the K/V token count), which can differ from total_q
k_buffer = torch.empty(
(total_k, group_size_k, head_dim), device=k.device, dtype=k.dtype
)
v_buffer = torch.empty(
(total_k, group_size_k, head_dim), device=v.device, dtype=v.dtype
)
else:
# bshd layout: (batch, seqlen_k, nheads_k_group, head_dim)
seqlen_k = k.shape[1]
k_buffer = torch.empty(
(batch, seqlen_k, group_size_k, head_dim), device=k.device, dtype=k.dtype
)
v_buffer = torch.empty(
(batch, seqlen_k, group_size_k, head_dim), device=v.device, dtype=v.dtype
)

softmax_lse_list = []

for g in range(n_groups):
start_h = g * group_size
end_h = min((g + 1) * group_size, nheads_q)
actual_heads = end_h - start_h

# For GQA, compute corresponding K,V head range
start_h_k = start_h // gqa_ratio
end_h_k = (end_h + gqa_ratio - 1) // gqa_ratio
actual_heads_k = end_h_k - start_h_k
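# e.g., with gqa_ratio=4, q heads [8, 16) map to k heads [2, 4)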

if IS_VARLEN:
# thd layout: (total_tokens, nheads, head_dim)
q_group = q[:, start_h:end_h, :] # strided view
o_group = o[:, start_h:end_h, :] # strided view, write directly

# Copy K,V into pre-allocated buffers
k_group = k_buffer[:, :actual_heads_k, :]
v_group = v_buffer[:, :actual_heads_k, :]
k_group.copy_(k[:, start_h_k:end_h_k, :])
v_group.copy_(v[:, start_h_k:end_h_k, :])
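# K/V are materialized into contiguous buffers because their working set
# is what the LLC sizing targets and is re-read for every Q block; Q and O
# are touched once per group, so strided views suffice for them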

# softmax_lse for varlen: (Hq, Total_Q)
softmax_lse_group = torch.zeros(
(actual_heads, total_q), device=q.device, dtype=torch.float32
)
else:
# bshd layout: (batch, seqlen, nheads, head_dim)
q_group = q[:, :, start_h:end_h, :] # strided view
o_group = o[:, :, start_h:end_h, :] # strided view, write directly

# Copy K,V into pre-allocated buffers
k_group = k_buffer[:, :, :actual_heads_k, :]
v_group = v_buffer[:, :, :actual_heads_k, :]
k_group.copy_(k[:, :, start_h_k:end_h_k, :])
v_group.copy_(v[:, :, start_h_k:end_h_k, :])

# softmax_lse for bshd: (B, Hq, Sq)
softmax_lse_group = torch.zeros(
(batch, actual_heads, softmax_lse.shape[-1]),
device=q.device,
dtype=torch.float32,
)

# Handle alibi slopes if present
alibi_group = None
if alibi_slopes is not None:
alibi_group = (
alibi_slopes[:, start_h:end_h]
if alibi_slopes.dim() == 2
else alibi_slopes[start_h:end_h]
)

# Call core implementation for this group
_attention_forward_prefill_triton_impl_core(
q_group,
k_group,
v_group,
o_group,
softmax_lse_group,
None,  # sd_mask: grouping is disabled above whenever it is requested
sm_scale,
alibi_group,
causal,
window_size_left,
window_size_right,
bias,
layout,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
dropout_p,
philox_seed,
philox_offset,
False,  # return_scores: grouping is disabled above whenever it is requested
use_exp2,
q_descale,
k_descale,
v_descale,
seqused_q,
seqused_k,
rotary_cos,
rotary_sin,
rotary_interleaved,
seqlens_rotary,
)

softmax_lse_list.append(softmax_lse_group)

# Concatenate softmax_lse across heads
if IS_VARLEN:
# varlen: (Hq, Total_Q) - concat on dim 0
final_lse = torch.cat(softmax_lse_list, dim=0)
else:
# bshd: (B, Hq, Sq) - concat on dim 1
final_lse = torch.cat(softmax_lse_list, dim=1)

# Copy back to caller's softmax_lse tensor
softmax_lse.copy_(final_lse)
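# No copy-back is needed for o: the grouped calls wrote through the
# strided o_group views directly into the caller's output tensor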