Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 102 additions & 8 deletions aiter/ops/triton/_triton_kernels/mha.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

from typing import Optional, Tuple
import functools
import json
import torch
Expand All @@ -10,9 +9,8 @@

from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
from ..utils._triton.pid_preprocessing import remap_xcd
from ..utils._triton.mha_kernel_utils import _compute_fp8_scaling_factors
from ..utils.device_info import get_num_xcds


@triton.jit
Expand Down Expand Up @@ -79,7 +77,9 @@ def _attn_fwd_inner(
l_i,
m_i,
q,
q_pe,
k_ptrs,
k_pe_ptrs,
v_ptrs,
stride_kn,
stride_vk,
Expand Down Expand Up @@ -107,6 +107,7 @@ def _attn_fwd_inner(
BLOCK_N: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DMODEL_POW2: tl.constexpr,
BLOCK_DMODEL_PE: tl.constexpr, # it's zero or a power of 2
SM_SCALE: tl.constexpr,
IS_CAUSAL: tl.constexpr,
MASK_STEPS: tl.constexpr,
Expand All @@ -115,12 +116,17 @@ def _attn_fwd_inner(
PADDED_HEAD: tl.constexpr,
IS_FP8: tl.constexpr,
FP8_MAX: tl.constexpr,
ENABLE_PIPELINING: tl.constexpr,
):
RCP_LN2: tl.constexpr = 1.4426950408889634
HAS_PE: tl.constexpr = BLOCK_DMODEL_PE > 0

# loop over k, v, and update accumulator

for start_n in range(block_min, block_max, BLOCK_N):
num_stages: tl.constexpr = (
None if ENABLE_PIPELINING else 1
) # Set num_stages==1 if we want to disable pipelining
for start_n in tl.range(block_min, block_max, BLOCK_N, num_stages=num_stages):
# For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range.
if MASK_STEPS:
Expand All @@ -129,6 +135,14 @@ def _attn_fwd_inner(
k_offs_n = None
k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL_POW2)
k = _load_fn(k_ptrs, k_offs_k, k_offs_n, BLOCK_DMODEL, seqlen_k)
if HAS_PE:
k_pe = _load_fn(
k_pe_ptrs,
None,
k_offs_n,
(BLOCK_DMODEL + BLOCK_DMODEL_PE),
seqlen_k,
)

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need
Expand Down Expand Up @@ -163,11 +177,13 @@ def _attn_fwd_inner(
qk += tl.dot(q, k) * descale_q * descale_k
else:
qk += tl.dot(q, k)
if HAS_PE:
qk += tl.dot(q_pe, k_pe)

if IS_CAUSAL:
causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
mask = mask and causal_mask
mask = mask & causal_mask

qk = tl.where(mask, qk, float("-inf"))

Expand Down Expand Up @@ -229,6 +245,8 @@ def _attn_fwd_inner(
acc += tl.dot(p.to(v.type.element_ty), v)

k_ptrs += BLOCK_N * stride_kn
if HAS_PE:
k_pe_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
if RETURN_SCORES:
sd_mask_ptrs += BLOCK_N * stride_sn
Expand Down Expand Up @@ -287,15 +305,16 @@ def _attn_fwd(
dropout_p,
philox_seed,
philox_offset_base_in,
SEQLEN_Q: tl.constexpr,
SEQLEN_K: tl.constexpr,
SEQLEN_Q,
SEQLEN_K,
IS_CAUSAL: tl.constexpr,
NUM_Q_HEADS: tl.constexpr,
NUM_K_HEADS: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DMODEL_POW2: tl.constexpr,
BLOCK_DMODEL_PE: tl.constexpr, # it's zero or a power of 2
RETURN_SCORES: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
IS_FP8: tl.constexpr,
Expand All @@ -321,6 +340,9 @@ def _attn_fwd(
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL_POW2)
HAS_PE: tl.constexpr = BLOCK_DMODEL_PE > 0
if HAS_PE:
offs_pe = BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL_PE)

# NOTE:
# Workaround for int64 strides, In the absence of strides being int64, parts of the offset
Expand Down Expand Up @@ -395,6 +417,38 @@ def _attn_fwd(
stride_lse_h = stride_lse_h_in
stride_lse_m = stride_lse_m_in

tl.assume(stride_qz_in >= 0)
tl.assume(stride_qh_in >= 0)
tl.assume(stride_qm_in >= 0)
tl.assume(stride_qk_in >= 0)
tl.assume(stride_kz_in >= 0)
tl.assume(stride_kh_in >= 0)
tl.assume(stride_kn_in >= 0)
tl.assume(stride_kk_in >= 0)
tl.assume(stride_vz_in >= 0)
tl.assume(stride_vh_in >= 0)
tl.assume(stride_vn_in >= 0)
tl.assume(stride_vk_in >= 0)
if IS_FP8:
tl.assume(stride_descale_q_z_in >= 0)
tl.assume(stride_descale_k_z_in >= 0)
tl.assume(stride_descale_v_z_in >= 0)
tl.assume(stride_oz_in >= 0)
tl.assume(stride_oh_in >= 0)
tl.assume(stride_om_in >= 0)
tl.assume(stride_on_in >= 0)
tl.assume(stride_alibi_z_in >= 0)
tl.assume(stride_alibi_h_in >= 0)
# NOTE: philox offset is need in dropout pointer calculations
tl.assume(philox_offset_base_in >= 0)
tl.assume(stride_sd_z_in >= 0)
tl.assume(stride_sd_h_in >= 0)
tl.assume(stride_sd_m_in >= 0)
tl.assume(stride_sd_n_in >= 0)
tl.assume(stride_lse_z_in >= 0)
tl.assume(stride_lse_h_in >= 0)
tl.assume(stride_lse_m_in >= 0)

if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
Expand Down Expand Up @@ -479,6 +533,17 @@ def _attn_fwd(
+ offs_d[None, :] * stride_qk
)
q_ptrs = q_ptr + q_offs
if HAS_PE:
q_pe_offs = (
off_z * stride_qz
+ off_q_head * stride_qh
+ cu_seqlens_q_start * stride_qm
+ offs_m[:, None] * stride_qm
+ offs_pe[None, :] * stride_qk
)
q_pe_ptrs = q_ptr + q_pe_offs
else:
q_pe_ptrs = None

k_offs = (
off_z * stride_kz
Expand All @@ -488,6 +553,17 @@ def _attn_fwd(
+ offs_n[None, :] * stride_kn
)
k_ptrs = k_ptr + k_offs
if HAS_PE:
k_pe_offs = (
off_z * stride_kz
+ off_k_head * stride_kh
+ cu_seqlens_k_start * stride_kn
+ offs_pe[:, None] * stride_kk
+ offs_n[None, :] * stride_kn
)
k_pe_ptrs = k_ptr + k_pe_offs
else:
k_pe_ptrs = None

v_offs = (
off_z * stride_vz
Expand Down Expand Up @@ -545,6 +621,11 @@ def _attn_fwd(
else:
q_mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < BLOCK_DMODEL)
q = tl.load(q_ptrs, mask=q_mask, other=0.0)
if HAS_PE:
q_pe = tl.load(q_pe_ptrs, mask=q_mask, other=0.0)
else:
q_pe = None

if IS_FP8:
descale_q = tl.load(descale_q_ptr + off_z * stride_descale_q_z + off_q_head)
descale_k = tl.load(descale_k_ptr + off_z * stride_descale_k_z + off_k_head)
Expand Down Expand Up @@ -584,7 +665,9 @@ def _attn_fwd(
l_i,
m_i,
q,
q_pe,
k_ptrs,
k_pe_ptrs,
v_ptrs,
stride_kn,
stride_vn,
Expand Down Expand Up @@ -612,6 +695,7 @@ def _attn_fwd(
BLOCK_N,
BLOCK_DMODEL,
BLOCK_DMODEL_POW2,
BLOCK_DMODEL_PE,
sm_scale,
False,
MASK_STEPS=False,
Expand All @@ -620,6 +704,7 @@ def _attn_fwd(
PADDED_HEAD=BLOCK_DMODEL != BLOCK_DMODEL_POW2,
IS_FP8=IS_FP8,
FP8_MAX=FP8_MAX,
ENABLE_PIPELINING=True,
)
block_min = block_max
block_max = n_blocks * BLOCK_N
Expand All @@ -631,6 +716,8 @@ def _attn_fwd(
else:
offs_n_causal = 0
k_ptrs += n_full_blocks * BLOCK_N * stride_kn
if HAS_PE:
k_pe_ptrs += n_full_blocks * BLOCK_N * stride_kn
v_ptrs += n_full_blocks * BLOCK_N * stride_vn
if RETURN_SCORES:
s_dmask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n
Expand All @@ -641,7 +728,9 @@ def _attn_fwd(
l_i,
m_i,
q,
q_pe,
k_ptrs,
k_pe_ptrs,
v_ptrs,
stride_kn,
stride_vn,
Expand Down Expand Up @@ -669,6 +758,7 @@ def _attn_fwd(
BLOCK_N,
BLOCK_DMODEL,
BLOCK_DMODEL_POW2,
BLOCK_DMODEL_PE,
sm_scale,
IS_CAUSAL,
MASK_STEPS=True,
Expand All @@ -677,6 +767,7 @@ def _attn_fwd(
PADDED_HEAD=BLOCK_DMODEL != BLOCK_DMODEL_POW2,
IS_FP8=IS_FP8,
FP8_MAX=FP8_MAX,
ENABLE_PIPELINING=False,
)
# epilogue
# This helps the compiler do Newton Raphson on l_i vs on acc which is much larger.
Expand Down Expand Up @@ -759,6 +850,7 @@ def _attn_fwd(
def _get_config(
enable_dropout: bool,
dtype: torch.dtype,
has_pe: bool = False,
):
if not hasattr(_get_config, "_config_dict"):
dev = arch_info.get_device()
Expand All @@ -768,7 +860,9 @@ def _get_config(
config = json.load(file)
_get_config._config_dict["default"] = config

if enable_dropout or dtype == torch.float32:
if has_pe and "pe" in _get_config._config_dict["default"]["fwd"]:
return _get_config._config_dict["default"]["fwd"]["pe"]
elif enable_dropout or dtype == torch.float32:
return _get_config._config_dict["default"]["fwd"]["dropout_or_fp32"]
else:
return _get_config._config_dict["default"]["fwd"]["default"]
Loading
Loading