Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion aiter/ops/triton/_triton_kernels/gemm_a16w16.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@
)


# Human-readable repr for the a16w16 split-K reduce kernel: tags the kernel
# name with the constexpr parameters that distinguish its compiled variants.
_gemm_a16w16_reduce_repr = make_kernel_repr(
    "_gemm_a16w16_reduce_kernel",
    [
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "ACTUAL_KSPLIT",
        "MAX_KSPLIT",
        "activation",
        "use_activation",
        "ADD_BIAS",
    ],
)


@triton.heuristics(
{
"EVEN_K": lambda args: (args["K"] % (args["SPLITK_BLOCK_SIZE"]) == 0)
Expand Down Expand Up @@ -169,7 +183,7 @@ def _gemm_a16_w16_kernel(
tl.store(c_ptrs, c, mask=c_mask)


@triton.jit
@triton.jit(repr=_gemm_a16w16_reduce_repr)
def _gemm_a16w16_reduce_kernel(
bias_ptr,
c_in_ptr,
Expand Down
19 changes: 18 additions & 1 deletion aiter/ops/triton/_triton_kernels/gemm_a16w16_atomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils._triton.kernel_repr import make_kernel_repr


# Human-readable repr for the atomic a16w16 GEMM kernel: tags the kernel name
# with the constexpr parameters that distinguish its compiled variants.
_gemm_a16w16_atomic_repr = make_kernel_repr(
    "_gemm_a16_w16_atomic_kernel",
    [
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "NUM_KSPLIT",
        "SPLITK_BLOCK_SIZE",
        "cache_modifier",
        "EVEN_K",
        "GRID_MN",
    ],
)


@triton.heuristics(
Expand All @@ -21,7 +38,7 @@
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
@triton.jit
@triton.jit(repr=_gemm_a16w16_atomic_repr)
def _gemm_a16_w16_atomic_kernel(
a_ptr,
b_ptr,
Expand Down
19 changes: 18 additions & 1 deletion aiter/ops/triton/_triton_kernels/gemm_a16w16_gated.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@
from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from .activation import _get_activation_from_str
from ..utils._triton.kernel_repr import make_kernel_repr


# Human-readable repr for the gated a16w16 GEMM kernel: tags the kernel name
# with the constexpr parameters that distinguish its compiled variants.
_gemm_a16w16_gated_repr = make_kernel_repr(
    "_gemm_a16_w16_gated_kernel",
    [
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "EVEN_K",
        "GRID_MN",
        "cache_modifier",
        "activation",
        "use_activation",
    ],
)


@triton.heuristics(
Expand All @@ -19,7 +36,7 @@
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
@triton.jit
@triton.jit(repr=_gemm_a16w16_gated_repr)
def _gemm_a16_w16_gated_kernel(
a_ptr,
b_ptr,
Expand Down
18 changes: 17 additions & 1 deletion aiter/ops/triton/_triton_kernels/gemm_a8w8.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,22 @@
import triton.language as tl
from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils._triton.kernel_repr import make_kernel_repr


# Human-readable repr for the a8w8 GEMM kernel: tags the kernel name with the
# constexpr parameters that distinguish its compiled variants.
_gemm_a8w8_repr = make_kernel_repr(
    "_gemm_a8w8_kernel",
    [
        "HAS_BIAS",
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "EVEN_K",
        "GRID_MN",
        "NUM_XCDS",
    ],
)


@triton.heuristics(
Expand All @@ -17,7 +33,7 @@
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
@triton.jit
@triton.jit(repr=_gemm_a8w8_repr)
def _gemm_a8w8_kernel(
# Pointers to matrices
a_ptr,
Expand Down
34 changes: 32 additions & 2 deletions aiter/ops/triton/_triton_kernels/gemm_a8w8_blockscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,36 @@
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils._triton.kernel_repr import make_kernel_repr


# Human-readable repr for the block-scaled a8w8 GEMM kernel: tags the kernel
# name with the constexpr parameters that distinguish its compiled variants.
_gemm_a8w8_blockscale_repr = make_kernel_repr(
    "_gemm_a8w8_blockscale_kernel",
    [
        "GROUP_K",
        "GROUP_N",
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "NUM_KSPLIT",
        "SPLITK_BLOCK_SIZE",
        "EVEN_K",
        "GRID_MN",
        "cache_modifier",
    ],
)


# Human-readable repr for the block-scaled a8w8 split-K reduce kernel: tags
# the kernel name with its variant-distinguishing constexpr parameters.
_gemm_a8w8_blockscale_reduce_repr = make_kernel_repr(
    "_gemm_a8w8_blockscale_reduce_kernel",
    ["BLOCK_SIZE_M", "BLOCK_SIZE_N", "ACTUAL_KSPLIT", "MAX_KSPLIT"],
)


@triton.heuristics(
Expand All @@ -20,7 +50,7 @@
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
@triton.jit
@triton.jit(repr=_gemm_a8w8_blockscale_repr)
def _gemm_a8w8_blockscale_kernel(
# Pointers to matrices
a_ptr,
Expand Down Expand Up @@ -195,7 +225,7 @@ def _gemm_a8w8_blockscale_kernel(
tl.store(c_ptrs, c, mask=c_mask)


@triton.jit
@triton.jit(repr=_gemm_a8w8_blockscale_reduce_repr)
def _gemm_a8w8_blockscale_reduce_kernel(
c_in_ptr,
c_out_ptr,
Expand Down
32 changes: 30 additions & 2 deletions aiter/ops/triton/_triton_kernels/gemm_a8w8_per_token_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,34 @@
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils._triton.kernel_repr import make_kernel_repr


# Human-readable repr for the per-token-scaled a8w8 GEMM kernel: tags the
# kernel name with the constexpr parameters that distinguish its variants.
_gemm_a8w8_per_token_scale_repr = make_kernel_repr(
    "_gemm_a8w8_per_token_scale_kernel",
    [
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "NUM_KSPLIT",
        "SPLITK_BLOCK_SIZE",
        "EVEN_K",
        "GRID_MN",
        "cache_modifier",
    ],
)


# Human-readable repr for the per-token-scaled a8w8 split-K reduce kernel:
# tags the kernel name with its variant-distinguishing constexpr parameters.
_gemm_a8w8_per_token_scale_reduce_repr = make_kernel_repr(
    "_gemm_a8w8_per_token_scale_reduce_kernel",
    ["BLOCK_SIZE_M", "BLOCK_SIZE_N", "ACTUAL_KSPLIT", "MAX_KSPLIT"],
)


@triton.heuristics(
Expand All @@ -18,7 +46,7 @@
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
@triton.jit
@triton.jit(repr=_gemm_a8w8_per_token_scale_repr)
def _gemm_a8w8_per_token_scale_kernel(
# Pointers to matrices
a_ptr,
Expand Down Expand Up @@ -167,7 +195,7 @@ def _gemm_a8w8_per_token_scale_kernel(
tl.store(c_ptrs, c, mask=c_mask)


@triton.jit
@triton.jit(repr=_gemm_a8w8_per_token_scale_reduce_repr)
def _gemm_a8w8_per_token_scale_reduce_kernel(
c_in_ptr,
c_out_ptr,
Expand Down
41 changes: 38 additions & 3 deletions aiter/ops/triton/_triton_kernels/gemm_a8wfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,40 @@
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils._triton.kernel_repr import make_kernel_repr


# Human-readable repr for the a8wfp4 GEMM kernel: tags the kernel name with
# the constexpr parameters that distinguish its compiled variants.
_gemm_a8wfp4_repr = make_kernel_repr(
    "_gemm_a8wfp4_kernel",
    [
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "NUM_KSPLIT",
        "SPLITK_BLOCK_SIZE",
        "EVEN_K",
        "GRID_MN",
        "RAW_MASKED_LOADS",
        "cache_modifier",
    ],
)

# Human-readable repr for the afp4/wfp4 split-K reduce kernel.
# NOTE(review): the original list here duplicated the main-kernel constexprs
# (BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE, EVEN_K, GRID_MN,
# RAW_MASKED_LOADS, cache_modifier) — an apparent copy-paste of
# _gemm_a8wfp4_repr. Every other *_reduce repr in this package lists only the
# reduce kernel's own constexprs, so this is aligned with them; confirm
# against _gemm_afp4_wfp4_reduce_kernel's signature.
_gemm_afp4_wfp4_reduce_repr = make_kernel_repr(
    "_gemm_afp4_wfp4_reduce_kernel",
    [
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "ACTUAL_KSPLIT",
        "MAX_KSPLIT",
    ],
)


@triton.heuristics(
Expand All @@ -19,7 +53,7 @@
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
@triton.jit
@triton.jit(repr=_gemm_a8wfp4_repr)
def _gemm_a8wfp4_kernel(
a_ptr,
b_ptr,
Expand Down Expand Up @@ -52,7 +86,8 @@ def _gemm_a8wfp4_kernel(
RAW_MASKED_LOADS: tl.constexpr,
cache_modifier: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
"""
Kernel for computing the matmul C = A x B.
A is in fp8 e4m3 format.
B is in the microscale fp4 (mxfp4) format.
A_scales and B_scales are in e8m0 format.
Expand Down Expand Up @@ -183,7 +218,7 @@ def _gemm_a8wfp4_kernel(
tl.store(c_ptrs, c, mask=c_mask)


@triton.jit
@triton.jit(repr=_gemm_afp4_wfp4_reduce_repr)
def _gemm_afp4_wfp4_reduce_kernel(
c_in_ptr,
c_out_ptr,
Expand Down
Loading