32 changes: 32 additions & 0 deletions aiter/ops/triton/_triton_kernels/moe/activations.py
@@ -0,0 +1,32 @@
import triton
import triton.language as tl


@triton.jit
def clip(x, limit, clip_lower: tl.constexpr):
    res = tl.minimum(x, limit)
    if clip_lower:
        res = tl.maximum(-limit, res)
    return res


@triton.jit
def _swiglu(input, alpha, limit, ADD_RESIDUAL: tl.constexpr):
    """
    SwiGLU activation.

    The last dim of `input` holds interleaved (gelu, linear) pairs. s = gelu * sigmoid(alpha * gelu);
    returns s * (linear + 1) if ADD_RESIDUAL else s * linear. With alpha = 1.0 the gating is the standard SiLU.
    """
    gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
    gelu = gelu.to(tl.float32)
    if limit is not None:
        gelu = clip(gelu, limit, clip_lower=False)
    linear = linear.to(tl.float32)
    if limit is not None:
        linear = clip(linear, limit, clip_lower=True)
    s = gelu / (1 + tl.exp2(-1.44269504089 * alpha * gelu))  # gelu * sigmoid(alpha * gelu); 1.44269504089 = log2(e)
    if ADD_RESIDUAL:
        return tl.fma(s, linear, s)  # s * (linear + 1)
    else:
        return s * linear
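
For reference, a minimal plain-PyTorch sketch of the same activation (the helper name swiglu_reference is hypothetical and torch is assumed available; it mirrors the interleaved (gelu, linear) layout produced by tl.split/tl.reshape above):

import torch

def swiglu_reference(x: torch.Tensor, alpha: float, limit, add_residual: bool) -> torch.Tensor:
    # Split the last dimension into interleaved (gelu, linear) pairs,
    # mirroring tl.split(tl.reshape(...)) in the Triton kernel.
    gelu, linear = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    if limit is not None:
        gelu = gelu.clamp(max=limit)                  # clip_lower=False
        linear = linear.clamp(min=-limit, max=limit)  # clip_lower=True
    s = gelu * torch.sigmoid(alpha * gelu)            # SiLU gating scaled by alpha
    return s * (linear + 1) if add_residual else s * linear
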
94 changes: 3 additions & 91 deletions aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a4w4.py
@@ -7,6 +7,7 @@
from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid
from aiter.ops.triton._triton_kernels.moe.quant_moe import _compute_static_fp8_quant
from aiter.ops.triton._triton_kernels.quant.quant import _mxfp4_quant_op
from aiter.ops.triton._triton_kernels.moe.activations import _swiglu


def matmul_launch_metadata(grid, kernel, args):
@@ -105,96 +106,6 @@ def unswizzle_mx_scale_cdna4(
return x


@triton.jit
def clip(x, limit, clip_lower: tl.constexpr):
res = tl.minimum(x, limit)
if clip_lower:
res = tl.maximum(-limit, res)
return res


@triton.jit
def _swiglu(input, alpha, limit):
gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
gelu = gelu.to(tl.float32)
if limit is not None:
gelu = clip(gelu, limit, clip_lower=False)
linear = linear.to(tl.float32)
if limit is not None:
linear = clip(linear, limit, clip_lower=True)
s = gelu / (1 + tl.exp2(-1.44269504089 * alpha * gelu))
return tl.fma(s, linear, s) # (s * (linear + 1))


@triton.jit
def _reduce_grouped(
X,
stride_xb: tl.uint64,
stride_xm: tl.uint64,
stride_xn, #
Out,
stride_om: tl.uint64,
stride_on, # output tensor
InIndx,
B,
N, #
# fused activation function
APPLY_SWIGLU: tl.constexpr,
alpha,
limit,
ACTIVATION_REDUCTION_N: tl.constexpr,
K: tl.constexpr,
BLOCK_N: tl.constexpr,
EVEN_N: tl.constexpr,
):
pid_t = tl.program_id(1)
pid_n = tl.program_id(0)

BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
start = pid_t * K
# load indices into a tuple
if InIndx is None:
indxs = (pid_t,)
else:
indxs = ()
for i in tl.static_range(0, K):
indxs = indxs + (tl.load(InIndx + start + i),)
XPtrs = X + (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) * stride_xn
OutPtrs = Out + (pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT)) * stride_on

acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32)
x_n_mask = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) < N
# accumulate contributions for this tile
for i in tl.static_range(0, K):
curr = tl.zeros([BLOCK_N], dtype=tl.float32)
# iterate over split_k partial values
for b in tl.range(0, B):
x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb
if EVEN_N:
vals = tl.load(x_row_ptr)
else:
vals = tl.load(x_row_ptr, mask=x_n_mask, other=0.0)
vals = vals.to(tl.float32)
curr += vals

# apply nonlinearity to split-k output
if APPLY_SWIGLU:
curr = _swiglu(curr[None, :], alpha, limit)
curr = tl.reshape(curr, [curr.shape[-1]])
# update final accumulator
acc += curr
# Compute per-32-col MXFP scales for this tile if requested
Nrem = N // ACTIVATION_REDUCTION_N

# write-back for this tile
out_ptr = OutPtrs + pid_t * stride_om
if EVEN_N:
tl.store(out_ptr, acc)
else:
out_n_mask = pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT) < Nrem
tl.store(out_ptr, acc, mask=out_n_mask)


@triton.jit
def _mxfp4_quant_kernel(
x_ptr,
@@ -298,6 +209,7 @@ def _moe_gemm_a4w4(
alpha,
limit,
ACTIVATION_REDUCTION_N: tl.constexpr,
ADD_RESIDUAL: tl.constexpr,  # if True, _swiglu returns s * (linear + 1) instead of s * linear
# MoE config
N_EXPTS_ACT: tl.constexpr,
# optimization config
@@ -556,7 +468,7 @@ def _moe_gemm_a4w4(
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
acc = acc + bias[None, :]
if APPLY_SWIGLU and SPLIT_K == 1:
out = _swiglu(acc, alpha, limit)
out = _swiglu(acc, alpha, limit, ADD_RESIDUAL=ADD_RESIDUAL)
tl.static_assert(
out.shape[1] == OUT_BLOCK_N,
f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})",
94 changes: 3 additions & 91 deletions aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w4.py
@@ -6,6 +6,7 @@
import triton.language as tl
from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid
from aiter.ops.triton._triton_kernels.moe.quant_moe import _compute_static_fp8_quant
from aiter.ops.triton._triton_kernels.moe.activations import _swiglu


def matmul_launch_metadata(grid, kernel, args):
@@ -104,96 +105,6 @@ def unswizzle_mx_scale_cdna4(
return x


@triton.jit
def clip(x, limit, clip_lower: tl.constexpr):
res = tl.minimum(x, limit)
if clip_lower:
res = tl.maximum(-limit, res)
return res


@triton.jit
def _swiglu(input, alpha, limit):
gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
gelu = gelu.to(tl.float32)
if limit is not None:
gelu = clip(gelu, limit, clip_lower=False)
linear = linear.to(tl.float32)
if limit is not None:
linear = clip(linear, limit, clip_lower=True)
s = gelu / (1 + tl.exp2(-1.44269504089 * alpha * gelu))
return tl.fma(s, linear, s) # (s * (linear + 1))


@triton.jit
def _reduce_grouped(
X,
stride_xb: tl.uint64,
stride_xm: tl.uint64,
stride_xn, #
Out,
stride_om: tl.uint64,
stride_on, # output tensor
InIndx,
B,
N, #
# fused activation function
APPLY_SWIGLU: tl.constexpr,
alpha,
limit,
ACTIVATION_REDUCTION_N: tl.constexpr,
K: tl.constexpr,
BLOCK_N: tl.constexpr,
EVEN_N: tl.constexpr,
):
pid_t = tl.program_id(1)
pid_n = tl.program_id(0)

BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
start = pid_t * K
# load indices into a tuple
if InIndx is None:
indxs = (pid_t,)
else:
indxs = ()
for i in tl.static_range(0, K):
indxs = indxs + (tl.load(InIndx + start + i),)
XPtrs = X + (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) * stride_xn
OutPtrs = Out + (pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT)) * stride_on

acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32)
x_n_mask = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) < N
# accumulate contributions for this tile
for i in tl.static_range(0, K):
curr = tl.zeros([BLOCK_N], dtype=tl.float32)
# iterate over split_k partial values
for b in tl.range(0, B):
x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb
if EVEN_N:
vals = tl.load(x_row_ptr)
else:
vals = tl.load(x_row_ptr, mask=x_n_mask, other=0.0)
vals = vals.to(tl.float32)
curr += vals

# apply nonlinearity to split-k output
if APPLY_SWIGLU:
curr = _swiglu(curr[None, :], alpha, limit)
curr = tl.reshape(curr, [curr.shape[-1]])
# update final accumulator
acc += curr
# Compute per-32-col MXFP scales for this tile if requested
Nrem = N // ACTIVATION_REDUCTION_N

# write-back for this tile
out_ptr = OutPtrs + pid_t * stride_om
if EVEN_N:
tl.store(out_ptr, acc)
else:
out_n_mask = pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT) < Nrem
tl.store(out_ptr, acc, mask=out_n_mask)


@triton.jit(launch_metadata=matmul_launch_metadata)
def _moe_gemm_a8w4(
Y,
@@ -235,6 +146,7 @@ def _moe_gemm_a8w4(
alpha,
limit,
ACTIVATION_REDUCTION_N: tl.constexpr,
ADD_RESIDUAL: tl.constexpr,  # if True, _swiglu returns s * (linear + 1) instead of s * linear
# MoE config
N_EXPTS_ACT: tl.constexpr,
# optimization config
@@ -481,7 +393,7 @@ def _moe_gemm_a8w4(
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
acc = acc + bias[None, :]
if APPLY_SWIGLU and SPLIT_K == 1:
out = _swiglu(acc, alpha, limit)
out = _swiglu(acc, alpha, limit, ADD_RESIDUAL=ADD_RESIDUAL)
tl.static_assert(
out.shape[1] == OUT_BLOCK_N,
f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})",
94 changes: 3 additions & 91 deletions aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w8.py
@@ -6,6 +6,7 @@
import triton.language as tl
from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid
from aiter.ops.triton._triton_kernels.moe.quant_moe import _compute_static_fp8_quant
from aiter.ops.triton._triton_kernels.moe.activations import _swiglu


def matmul_launch_metadata(grid, kernel, args):
@@ -104,96 +105,6 @@ def unswizzle_mx_scale_cdna4(
return x


@triton.jit
def clip(x, limit, clip_lower: tl.constexpr):
res = tl.minimum(x, limit)
if clip_lower:
res = tl.maximum(-limit, res)
return res


@triton.jit
def _swiglu(input, alpha, limit):
gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
gelu = gelu.to(tl.float32)
if limit is not None:
gelu = clip(gelu, limit, clip_lower=False)
linear = linear.to(tl.float32)
if limit is not None:
linear = clip(linear, limit, clip_lower=True)
s = gelu / (1 + tl.exp2(-1.44269504089 * alpha * gelu))
return tl.fma(s, linear, s) # (s * (linear + 1))


@triton.jit
def _reduce_grouped(
X,
stride_xb: tl.uint64,
stride_xm: tl.uint64,
stride_xn, #
Out,
stride_om: tl.uint64,
stride_on, # output tensor
InIndx,
B,
N, #
# fused activation function
APPLY_SWIGLU: tl.constexpr,
alpha,
limit,
ACTIVATION_REDUCTION_N: tl.constexpr,
K: tl.constexpr,
BLOCK_N: tl.constexpr,
EVEN_N: tl.constexpr,
):
pid_t = tl.program_id(1)
pid_n = tl.program_id(0)

BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
start = pid_t * K
# load indices into a tuple
if InIndx is None:
indxs = (pid_t,)
else:
indxs = ()
for i in tl.static_range(0, K):
indxs = indxs + (tl.load(InIndx + start + i),)
XPtrs = X + (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) * stride_xn
OutPtrs = Out + (pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT)) * stride_on

acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32)
x_n_mask = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) < N
# accumulate contributions for this tile
for i in tl.static_range(0, K):
curr = tl.zeros([BLOCK_N], dtype=tl.float32)
# iterate over split_k partial values
for b in tl.range(0, B):
x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb
if EVEN_N:
vals = tl.load(x_row_ptr)
else:
vals = tl.load(x_row_ptr, mask=x_n_mask, other=0.0)
vals = vals.to(tl.float32)
curr += vals

# apply nonlinearity to split-k output
if APPLY_SWIGLU:
curr = _swiglu(curr[None, :], alpha, limit)
curr = tl.reshape(curr, [curr.shape[-1]])
# update final accumulator
acc += curr
# Compute per-32-col MXFP scales for this tile if requested
Nrem = N // ACTIVATION_REDUCTION_N

# write-back for this tile
out_ptr = OutPtrs + pid_t * stride_om
if EVEN_N:
tl.store(out_ptr, acc)
else:
out_n_mask = pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT) < Nrem
tl.store(out_ptr, acc, mask=out_n_mask)


@triton.jit(launch_metadata=matmul_launch_metadata)
def _moe_gemm_a8w8(
Y,
@@ -236,6 +147,7 @@ def _moe_gemm_a8w8(
alpha,
limit,
ACTIVATION_REDUCTION_N: tl.constexpr,
ADD_RESIDUAL: tl.constexpr,  # if True, _swiglu returns s * (linear + 1) instead of s * linear
# MoE config
N_EXPTS_ACT: tl.constexpr,
# optimization config
@@ -491,7 +403,7 @@ def _moe_gemm_a8w8(
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
acc = acc + bias[None, :]
if APPLY_SWIGLU and SPLIT_K == 1:
out = _swiglu(acc, alpha, limit)
out = _swiglu(acc, alpha, limit, ADD_RESIDUAL=ADD_RESIDUAL)
tl.static_assert(
out.shape[1] == OUT_BLOCK_N,
f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})",