Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from .quant import _mxfp4_quant_op

_batched_gemm_afp4_wfp4_pre_quant_repr = make_kernel_repr(
"_batched_gemm_afp4_wfp4_pre_quant_kernel",
_batched_gemm_a16wfp4_repr = make_kernel_repr(
"_batched_gemm_a16wfp4_kernel",
[
"BLOCK_SIZE_M",
"BLOCK_SIZE_N",
Expand All @@ -25,13 +25,15 @@
"SPLITK_BLOCK_SIZE",
"EVEN_K",
"GRID_MN",
"PRE_QUANT",
"HAVE_Y_SCALE",
"cache_modifier",
],
)


_batched_gemm_afp4_wfp4_pre_quant_reduce_repr = make_kernel_repr(
"_batched_gemm_afp4_wfp4_pre_quant_reduce_kernel",
_batched_gemm_a16wfp4_reduce_repr = make_kernel_repr(
"_batched_gemm_a16wfp4_reduce_kernel",
[
"BLOCK_SIZE_M",
"BLOCK_SIZE_N",
Expand All @@ -50,12 +52,13 @@
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
@triton.jit(repr=_batched_gemm_afp4_wfp4_pre_quant_repr)
def _batched_gemm_afp4_wfp4_pre_quant_kernel(
@triton.jit(repr=_batched_gemm_a16wfp4_repr)
def _batched_gemm_a16wfp4_kernel(
a_ptr,
b_ptr,
c_ptr,
b_scales_ptr,
c_scale_ptr,
M,
N,
K,
Expand All @@ -81,6 +84,8 @@ def _batched_gemm_afp4_wfp4_pre_quant_kernel(
SPLITK_BLOCK_SIZE: tl.constexpr,
EVEN_K: tl.constexpr,
GRID_MN: tl.constexpr,
PRE_QUANT: tl.constexpr,
HAVE_Y_SCALE: tl.constexpr,
cache_modifier: tl.constexpr,
):
"""
Expand Down Expand Up @@ -121,6 +126,12 @@ def _batched_gemm_afp4_wfp4_pre_quant_kernel(
stride_cb = tl.cast(stride_cb, tl.int64)
pid_batch = tl.cast(pid_batch, tl.int64)

if HAVE_Y_SCALE:
c_scale = tl.load(c_scale_ptr)
else:
c_scale = 1
c_scale_rcprl = (1 / c_scale).to(tl.float32)

if NUM_KSPLIT == 1:
remap_xcd(pid, GRID_MN)

Expand Down Expand Up @@ -189,7 +200,8 @@ def _batched_gemm_afp4_wfp4_pre_quant_kernel(
b_ptrs, mask=offs_k[:, None] < K - k * (BLOCK_SIZE_K // 2), other=0
)

a, a_scales = _mxfp4_quant_op(a_bf16, BLOCK_SIZE_K, BLOCK_SIZE_M, 32)
if PRE_QUANT: # TODO add PRE_QUANT = False
a, a_scales = _mxfp4_quant_op(a_bf16, BLOCK_SIZE_K, BLOCK_SIZE_M, 32)

accumulator += tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1")

Expand All @@ -198,6 +210,9 @@ def _batched_gemm_afp4_wfp4_pre_quant_kernel(
b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
b_scale_ptrs += (BLOCK_SIZE_K // SCALE_GROUP_SIZE) * stride_bsk

if HAVE_Y_SCALE:
accumulator = accumulator * c_scale_rcprl

c = accumulator.to(c_ptr.type.element_ty)

# Write back the block of the output matrix C with masks.
Expand All @@ -214,8 +229,8 @@ def _batched_gemm_afp4_wfp4_pre_quant_kernel(
tl.store(c_ptrs, c, mask=c_mask)


@triton.jit(repr=_batched_gemm_afp4_wfp4_pre_quant_reduce_repr)
def _batched_gemm_afp4_wfp4_pre_quant_reduce_kernel(
@triton.jit(repr=_batched_gemm_a16wfp4_reduce_repr)
def _batched_gemm_a16wfp4_reduce_kernel(
c_in_ptr,
c_out_ptr,
M,
Expand Down
Loading