Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aiter/ops/triton/_triton_kernels/pa_mqa_logits.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def _gluon_deepgemm_fp8_paged_mqa_logits(
max_model_len,
max_block_len,
SplitKV,
dummyPointerArg,
dummyPointerArg, # dummy pointer for compatibility with triton3.5 on lower version
ChunkQ: tl.constexpr,
ChunkK: tl.constexpr,
HiddenDim: tl.constexpr,
Expand Down Expand Up @@ -449,7 +449,7 @@ def _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle(
max_model_len,
max_block_len,
SplitKV,
dummyPointerArg,
dummyPointerArg, # dummy pointer for compatibility with triton3.5 on lower version
ChunkQ: tl.constexpr,
ChunkK: tl.constexpr,
HiddenDim: tl.constexpr,
Expand Down
43 changes: 26 additions & 17 deletions aiter/ops/triton/pa_mqa_logits.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,53 +21,54 @@
# ========================================================================

import os
import torch
import triton
from functools import lru_cache

import torch
import triton
from packaging.version import Version
from triton.backends.compiler import GPUTarget

enable_aot_gluon_pa_mqa_logits = os.environ.get(
"AITER_ENABLE_AOT_GLUON_PA_MQA_LOGITS", "0"
)
enable_aot_gluon_pa_mqa_logits = enable_aot_gluon_pa_mqa_logits == "1"

if triton.__version__ >= "3.5.0":
triton_version = Version(Version(triton.__version__).base_version)
if triton_version >= Version("3.5.0"):
from triton.experimental.gluon._runtime import GluonASTSource as ASTSource

from aiter.ops.triton._triton_kernels.pa_mqa_logits import (
_deepgemm_fp8_paged_mqa_logits_stage1,
_deepgemm_fp8_paged_mqa_logits_stage1_ragged_k,
_deepgemm_fp8_paged_mqa_logits,
_deepgemm_fp8_paged_mqa_logits_ragged_k,
_deepgemm_fp8_paged_mqa_logits_stage1,
_deepgemm_fp8_paged_mqa_logits_stage1_ragged_k,
)
from aiter.ops.triton.gluon.pa_mqa_logits import (
_gluon_deepgemm_fp8_paged_mqa_logits,
_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle,
)

enable_gluon_pa_mqa_logits = True
enable_jit_gluon_pa_mqa_logits_kernel = True
enable_jit_gluon_pa_mqa_logits_kernel = not enable_aot_gluon_pa_mqa_logits
else:
from triton.compiler import ASTSource

from aiter.ops.triton._triton_kernels.pa_mqa_logits import (
_deepgemm_fp8_paged_mqa_logits_stage1,
_deepgemm_fp8_paged_mqa_logits_stage1_ragged_k,
_deepgemm_fp8_paged_mqa_logits,
_deepgemm_fp8_paged_mqa_logits_ragged_k,
_deepgemm_fp8_paged_mqa_logits_stage1,
_deepgemm_fp8_paged_mqa_logits_stage1_ragged_k,
_gluon_deepgemm_fp8_paged_mqa_logits,
_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle,
)

assert triton.__version__ < "3.4.0"
enable_gluon_pa_mqa_logits = enable_aot_gluon_pa_mqa_logits
enable_jit_gluon_pa_mqa_logits_kernel = False


from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
from aiter.utility.triton.triton_metadata_redirect import (
AOTMetadataContext,
)
from aiter import dtypes
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
from aiter.utility.triton.triton_metadata_redirect import AOTMetadataContext

from ...jit.utils.chip_info import get_gfx


Expand Down Expand Up @@ -273,7 +274,8 @@ def _compile_deepgemm_fp8_paged_mqa_logits(
"max_block_len": "i32",
"SplitKV": "i32",
}
if not enable_jit_gluon_pa_mqa_logits_kernel:
if triton_version < Version("3.4.0"):
assert not enable_jit_gluon_pa_mqa_logits_kernel
fn_signature["dummyPointerArg"] = "*i32"
fn_signature["ChunkQ"] = "constexpr"
fn_signature["ChunkK"] = "constexpr"
Expand Down Expand Up @@ -405,7 +407,7 @@ def deepgemm_fp8_paged_mqa_logits(
is_padded_mode=is_padded_mode,
WavePerEU=WavePerEU,
)
if enable_jit_gluon_pa_mqa_logits_kernel:
if triton_version >= Version("3.5.0"):
kernel[grid](
batch_size,
next_n,
Expand Down Expand Up @@ -434,6 +436,13 @@ def deepgemm_fp8_paged_mqa_logits(
hidden_dim,
)
else: # load AOT compiled gluon kernel
assert triton_version < Version(
"3.4.0"
), "https://github.com/triton-lang/triton/pull/7258 involves a ABI-breaking change on triton3.4, "
"which adding an extra pointer argument at the end of kernel arguments. To ensure compatibility"
"with AOT compiled gluon kernel on triton3.5, a feasible solution is to add a pointer parameter "
"at the end of the parameters and ensure that the Triton version used is before the ABI "
"modification, i.e., verison<3.4.0"
kernel[grid](
batch_size,
next_n,
Expand All @@ -456,14 +465,14 @@ def deepgemm_fp8_paged_mqa_logits(
max_block_len,
SplitKV,
out_logits, # dummyPointerArg for triton version < 3.4.0,
# the kernel signature has an extra pointer argument on triton>=3.5.0
# constexpr
heads,
ChunkK,
KVBlockSize,
hidden_dim,
)
else:
assert KVBlockSize == 1
assert not Preshuffle, "Preshuffle mode is only supported on gluon kernel."
kernel = _deepgemm_fp8_paged_mqa_logits[grid](
batch_size,
Expand Down
Loading