Skip to content
18 changes: 16 additions & 2 deletions aiter/ops/triton/_triton_kernels/chunked_pa_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,29 @@

import triton
import triton.language as tl
from ..utils._triton.kernel_repr import make_kernel_repr


@triton.jit
def cdiv_fn(x, y):
    # Ceiling division for non-negative integers: smallest q with q * y >= x.
    rounded_up = x + y - 1
    return rounded_up // y


@triton.jit
# Readable repr for the 2D paged-attention kernel, keyed on its
# compile-time parameters (presumably used for logging/cache naming —
# confirm in kernel_repr module).
_kernel_paged_attention_2d_repr = make_kernel_repr(
    "_kernel_paged_attention_2d",
    "num_queries_per_kv BLOCK_SIZE HEAD_SIZE USE_ALIBI_SLOPES "
    "SLIDING_WINDOW x filter_by_query_len".split(),
)


@triton.jit(repr=_kernel_paged_attention_2d_repr)
def _kernel_paged_attention_2d(
output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size]
Expand All @@ -31,7 +46,6 @@ def _kernel_paged_attention_2d(
scale, # float32
k_scale, # float32
v_scale, # float32
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
block_table_stride: tl.constexpr, # int
query_stride_0: tl.constexpr, # int
Expand Down
62 changes: 38 additions & 24 deletions aiter/ops/triton/_triton_kernels/hstu_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple
import json

# @manual=//triton:triton
Expand All @@ -22,9 +21,9 @@
# @manual=//triton:triton
import triton.language as tl
import functools
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
from ..utils._triton import arch_info
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils._triton.kernel_repr import make_kernel_repr

try:
from triton.language.extra.libdevice import (
Expand Down Expand Up @@ -315,7 +314,25 @@ def _hstu_attn_fwd_compute( # noqa C901
tl.store(out_ptrs, acc, mask=(offs_m < seq_len)[:, None])


@triton.jit
# Readable repr for the HSTU attention forward kernel, built from the
# constexpr parameters that distinguish its specializations.
_hstu_attn_fwd_repr = make_kernel_repr(
    "_hstu_attn_fwd",
    "CAUSAL HAS_MULTIPLE_TARGETS IS_DELTA_Q ALLOW_TF32 "
    "BLOCK_D_Q BLOCK_D_V BLOCK_M BLOCK_N "
    "HAS_CONTEXTUAL_SEQ_LEN HAS_MAX_ATTN_LEN "
    "HAS_SORT_BY_LENGTH_INDICES".split(),
)


@triton.jit(repr=_hstu_attn_fwd_repr)
def _hstu_attn_fwd( # noqa C901
Q,
K,
Expand All @@ -333,13 +350,8 @@ def _hstu_attn_fwd( # noqa C901
stride_om,
stride_oh,
alpha,
Z,
AUTOTUNE_Z,
H,
MAX_SEQ_LEN,
AUTOTUNE_MAX_SEQ_LEN, # Quantized MAX_SEQ_LEN used as an autotuning key
DimQ,
DimV,
DeltaSize,
contextual_seq_len,
max_attn_len,
Expand Down Expand Up @@ -693,7 +705,24 @@ def _hstu_attn_bwd_one_col_block( # noqa C901
tl.store(dk_ptrs, dk.to(k.dtype), mask=mask_n[:, None])


@triton.jit
# Readable repr for the HSTU attention backward kernel; mirrors the
# forward repr minus the delta-Q flag (no IS_DELTA_Q in the bwd pass).
_hstu_attn_bwd_repr = make_kernel_repr(
    "_hstu_attn_bwd",
    "CAUSAL HAS_MULTIPLE_TARGETS ALLOW_TF32 "
    "BLOCK_D_Q BLOCK_D_V BLOCK_M BLOCK_N "
    "HAS_CONTEXTUAL_SEQ_LEN HAS_MAX_ATTN_LEN "
    "HAS_SORT_BY_LENGTH_INDICES".split(),
)


@triton.jit(repr=_hstu_attn_bwd_repr)
def _hstu_attn_bwd( # noqa C901
Q,
K,
Expand Down Expand Up @@ -723,13 +752,8 @@ def _hstu_attn_bwd( # noqa C901
alpha,
contextual_seq_len,
max_attn_len,
Z,
AUTOTUNE_Z,
H,
MAX_SEQ_LEN,
AUTOTUNE_MAX_SEQ_LEN, # Quantized MAX_SEQ_LEN used as an autotuning key
DimQ,
DimV,
CAUSAL: tl.constexpr,
HAS_MULTIPLE_TARGETS: tl.constexpr,
HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr,
Expand Down Expand Up @@ -845,12 +869,6 @@ def _hstu_attn_bwd( # noqa C901
@functools.lru_cache(maxsize=1024)
def _get_fwd_config(
AUTOTUNE_Z: int,
H: int,
AUTOTUNE_MAX_SEQ_LEN: int,
DimQ: int,
DimV: int,
DeltaSize: int,
IS_DELTA_Q: bool,
):
if not hasattr(_get_fwd_config, "_config_dict"):
dev = arch_info.get_device()
Expand All @@ -872,10 +890,6 @@ def _get_fwd_config(
@functools.lru_cache(maxsize=1024)
def _get_bwd_config(
AUTOTUNE_Z: int,
H: int,
AUTOTUNE_MAX_SEQ_LEN: int,
DimQ: int,
DimV: int,
):
if not hasattr(_get_bwd_config, "_config_dict"):
dev = arch_info.get_device()
Expand Down
31 changes: 27 additions & 4 deletions aiter/ops/triton/_triton_kernels/lean_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
import json
import triton
import triton.language as tl
from typing import Optional
from bisect import bisect_right
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
from ..utils._triton.pid_preprocessing import remap_xcd
from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils._triton.kernel_repr import make_kernel_repr


# Support tensor in [B, Seqlen, H, d] format. Taking tensors in [B*Seqlen, H, d] as inputs
Expand Down Expand Up @@ -209,7 +208,31 @@ def remap_xcd(pid, GRID_MN: tl.constexpr, NUM_XCDS: tl.constexpr = 8):
return pid, pids_per_xcd


@triton.jit
# Readable repr for the persistent lean-attention kernel, covering both
# the tiling constexprs and the scheduling/work-distribution parameters.
_la_persistent_repr = make_kernel_repr(
    "la_persistent",
    "HEADS_PER_XCD HEAD_DIM BLOCK_M BLOCK_N MASKED_BLOCKS "
    "XCD_REMAP NUM_XCDS batch_size causal "
    "num_m_blocks num_n_blocks total_programs high_load_wgs "
    "max_tiles_per_wg tiles_per_head num_splits "
    "max_output_tile_cnt".split(),
)


@triton.jit(repr=_la_persistent_repr)
def la_persistent(
is_pod,
pod_pid,
Expand Down
22 changes: 21 additions & 1 deletion aiter/ops/triton/_triton_kernels/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils._triton.pid_preprocessing import remap_xcd
from ..utils._triton.mha_kernel_utils import _compute_fp8_scaling_factors
from ..utils._triton.kernel_repr import make_kernel_repr


@triton.jit
Expand Down Expand Up @@ -255,7 +256,26 @@ def _attn_fwd_inner(
return acc, l_i, m_i


@triton.jit
# Readable repr for the MHA forward kernel, keyed on the constexpr flags
# and tile sizes that select a specialization.
_attn_fwd_repr = make_kernel_repr(
    "_attn_fwd",
    "IS_CAUSAL NUM_Q_HEADS NUM_K_HEADS BLOCK_M BLOCK_N BLOCK_DMODEL "
    "RETURN_SCORES ENABLE_DROPOUT IS_FP8 VARLEN NUM_XCD "
    "USE_INT64_STRIDES".split(),
)


@triton.jit(repr=_attn_fwd_repr)
def _attn_fwd(
q_ptr: torch.Tensor,
k_ptr: torch.Tensor,
Expand Down
58 changes: 50 additions & 8 deletions aiter/ops/triton/_triton_kernels/mha_fused_bwd.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,36 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

from typing import Optional, Dict
import functools
import json
import torch
import triton
import triton.language as tl


from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
from ..utils._triton.pid_preprocessing import remap_xcd
from ..utils._triton.mha_kernel_utils import _compute_fp8_scaling_factors
from ..utils._triton.kernel_repr import make_kernel_repr


# This function computes delta given output Out and gradient DO
# Here is the I/O shape:
# Out: (batch, nhead_q, max_seqlens_q, headDim)
# DO: (batch, nhead_q, max_seqlens_q, headDim)
# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at
@triton.jit
# Readable repr for the backward-pass preprocess kernel (delta = f(Out, dO)).
_bwd_preprocess_repr = make_kernel_repr(
    "_bwd_preprocess",
    "BLOCK_M BLOCK_D_MODEL IS_VARLEN IS_FP8".split(),
)


@triton.jit(repr=_bwd_preprocess_repr)
def _bwd_preprocess(
o_ptr,
do_ptr, # noqa: E741
Expand Down Expand Up @@ -312,7 +322,25 @@ def _bwd_dkdvdq_inner(
return dk, dv


@triton.jit
# Readable repr for the causal fused dK/dV/dQ backward kernel.
_bwd_kernel_dkdvdq_causal_repr = make_kernel_repr(
    "_bwd_kernel_dkdvdq_causal",
    "NUM_Q_HEADS NUM_K_HEADS BLOCK_M BLOCK_N BLK_SLICE_FACTOR "
    "BLOCK_D_MODEL ENABLE_DROPOUT IS_VARLEN IS_FP8 "
    "USE_INT64_STRIDES NUM_XCD".split(),
)


@triton.jit(repr=_bwd_kernel_dkdvdq_causal_repr)
def _bwd_kernel_dkdvdq_causal(
q_ptr,
k_ptr,
Expand Down Expand Up @@ -384,7 +412,6 @@ def _bwd_kernel_dkdvdq_causal(
IS_VARLEN: tl.constexpr,
IS_FP8: tl.constexpr,
FP8_MAX: tl.constexpr,
NUM_SMS: tl.constexpr,
USE_INT64_STRIDES: tl.constexpr,
NUM_XCD: tl.constexpr,
):
Expand Down Expand Up @@ -714,7 +741,23 @@ def _bwd_kernel_dkdvdq_causal(
tl.atomic_add(dk_ptr + offs_dkdv, dk, mask=mask_kv, sem="relaxed")


@triton.jit
# Readable repr for the non-causal fused dK/dV/dQ backward kernel;
# same key set as the causal variant minus BLK_SLICE_FACTOR and NUM_XCD.
_bwd_kernel_dkdvdq_noncausal_repr = make_kernel_repr(
    "_bwd_kernel_dkdvdq_noncausal",
    "NUM_Q_HEADS NUM_K_HEADS BLOCK_M BLOCK_N BLOCK_D_MODEL "
    "ENABLE_DROPOUT IS_VARLEN IS_FP8 USE_INT64_STRIDES".split(),
)


@triton.jit(repr=_bwd_kernel_dkdvdq_noncausal_repr)
def _bwd_kernel_dkdvdq_noncausal(
Q,
K,
Expand Down Expand Up @@ -786,7 +829,6 @@ def _bwd_kernel_dkdvdq_noncausal(
IS_VARLEN: tl.constexpr,
IS_FP8: tl.constexpr,
FP8_MAX: tl.constexpr,
NUM_SMS: tl.constexpr,
USE_INT64_STRIDES: tl.constexpr,
):
if USE_INT64_STRIDES:
Expand Down
Loading
Loading