85 changes: 85 additions & 0 deletions aiter/ops/triton/_triton_kernels/activation.py
@@ -1,4 +1,5 @@
from .quant import _mxfp4_quant_op
from .fused_fp8_quant import _fp8_quant_op
import triton
import triton.language as tl

@@ -188,3 +189,87 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel(
bs_e8m0,
mask=bs_mask,
)


@triton.heuristics(
{
"EVEN_N": lambda args: args["N"] % args["BLOCK_SIZE_N"] == 0,
}
)
@triton.jit
def _act_mul_and_dynamic_fp8_group_quant_kernel(
x_ptr,
x_fp8_ptr,
x_bs_ptr,
stride_x_m_in,
stride_x_n_in,
stride_x_fp8_m_in,
stride_x_fp8_n_in,
stride_bs_m_in,
stride_bs_n_in,
N,
ACTIVATION: tl.constexpr,
scaleN: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
QUANT_BLOCK_SIZE: tl.constexpr,
DTYPE_MAX: tl.constexpr,
DTYPE_MIN: tl.constexpr,
EVEN_N: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
# cast strides to int64, in case M*N > max int32
stride_x_m = tl.cast(stride_x_m_in, tl.int64)
stride_x_n = tl.cast(stride_x_n_in, tl.int64)
stride_x_fp8_m = tl.cast(stride_x_fp8_m_in, tl.int64)
stride_x_fp8_n = tl.cast(stride_x_fp8_n_in, tl.int64)
stride_bs_m = tl.cast(stride_bs_m_in, tl.int64)
stride_bs_n = tl.cast(stride_bs_n_in, tl.int64)
NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE

x_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
x_offs = pid_m * stride_x_m + x_offs_n * stride_x_n

if EVEN_N:
a = tl.load(x_ptr + x_offs, cache_modifier=".cg").to(tl.float32)
b = tl.load(x_ptr + x_offs + stride_x_n * N, cache_modifier=".cg").to(
tl.float32
)
else:
x_mask = x_offs_n < N
a = tl.load(x_ptr + x_offs, mask=x_mask, cache_modifier=".cg").to(tl.float32)
# a and b can share the same mask
b = tl.load(
x_ptr + x_offs + stride_x_n * N, mask=x_mask, cache_modifier=".cg"
).to(tl.float32)

    x = _apply_activation_from_str(a, ACTIVATION) * b  # gated activation: act(gate) * up

x_fp8, x_bs = _fp8_quant_op(
x, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN
)
x_fp8 = tl.ravel(x_fp8)
x_bs = tl.ravel(x_bs)

out_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_offs = pid_m * stride_x_fp8_m + out_offs_n * stride_x_fp8_n

if EVEN_N:
tl.store(x_fp8_ptr + out_offs, x_fp8.to(x_fp8_ptr.dtype.element_ty))
else:
out_mask = out_offs_n < N
tl.store(
x_fp8_ptr + out_offs, x_fp8.to(x_fp8_ptr.dtype.element_ty), mask=out_mask
)

bs_offs_n = pid_n * NUM_QUANT_BLOCKS + tl.arange(0, NUM_QUANT_BLOCKS)
bs_offs = pid_m * stride_bs_m + bs_offs_n * stride_bs_n
if EVEN_N:
tl.store(x_bs_ptr + bs_offs, x_bs.to(x_bs_ptr.dtype.element_ty))
else:
bs_mask = bs_offs_n < scaleN
tl.store(
x_bs_ptr + bs_offs,
x_bs.to(x_bs_ptr.dtype.element_ty),
mask=bs_mask,
)
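
For reference, a minimal host-side launch sketch for the new kernel (not part of this PR). The wrapper name, tile width, and FP8 dtype below are assumptions; the input is expected to pack the gate and up halves side by side as an (M, 2*N) tensor, matching the `x_ptr + x_offs + stride_x_n * N` load above.

import torch
import triton

from aiter.ops.triton._triton_kernels.activation import (
    _act_mul_and_dynamic_fp8_group_quant_kernel,
)


def act_mul_fp8_group_quant(x: torch.Tensor, activation: str = "silu",
                            quant_block: int = 128):
    # x packs gate and up halves side by side: (M, 2*N)
    M, two_n = x.shape
    N = two_n // 2
    BLOCK_SIZE_N = 256  # assumed tile width; must be a multiple of quant_block
    scaleN = triton.cdiv(N, quant_block)

    fp8_dtype = torch.float8_e4m3fnuz  # assumed FP8 format (ROCm flavor)
    x_fp8 = torch.empty((M, N), dtype=fp8_dtype, device=x.device)
    x_bs = torch.empty((M, scaleN), dtype=torch.float32, device=x.device)

    finfo = torch.finfo(fp8_dtype)
    grid = (M, triton.cdiv(N, BLOCK_SIZE_N))  # one program per (row, column tile)
    _act_mul_and_dynamic_fp8_group_quant_kernel[grid](
        x, x_fp8, x_bs,
        x.stride(0), x.stride(1),
        x_fp8.stride(0), x_fp8.stride(1),
        x_bs.stride(0), x_bs.stride(1),
        N,
        ACTIVATION=activation,
        scaleN=scaleN,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        QUANT_BLOCK_SIZE=quant_block,
        DTYPE_MAX=finfo.max,
        DTYPE_MIN=finfo.min,
    )
    return x_fp8, x_bs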
81 changes: 81 additions & 0 deletions aiter/ops/triton/_triton_kernels/fused_add_rmsnorm_pad.py
@@ -0,0 +1,81 @@
import triton
import triton.language as tl


@triton.jit
def _rmsnorm_op(row, weight, n_cols, epsilon):
row_norm = row * row
row_norm = tl.sum(row_norm, axis=-1)
norm_factor = tl.math.rsqrt((row_norm / n_cols) + epsilon)

rms_norm = row * norm_factor * weight
return rms_norm


@triton.jit
def _fused_add_rmsnorm_pad(
x_ptr,
res_ptr,
out_ptr,
res_out_ptr,
weight_ptr,
eps,
M,
N,
N_OUT,
x_stride_m,
x_stride_n,
res_stride_m,
res_stride_n,
out_stride_m,
out_stride_n,
res_out_stride_m,
res_out_stride_n,
HAS_RES: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
tl.assume(x_stride_m > 0)
tl.assume(x_stride_n > 0)
tl.assume(res_stride_m > 0)
tl.assume(res_stride_n > 0)
tl.assume(out_stride_m > 0)
tl.assume(out_stride_n > 0)

pid_m = tl.program_id(0)
tl.assume(pid_m >= 0)

n_offs = tl.arange(0, BLOCK_SIZE_N)
mask = n_offs < N
x = tl.load(
x_ptr + pid_m * x_stride_m + n_offs * x_stride_n,
mask=mask,
other=0.0,
cache_modifier=".cg",
).to(tl.float32)
if HAS_RES:
res = tl.load(
res_ptr + pid_m * res_stride_m + n_offs * res_stride_n,
mask=mask,
other=0.0,
cache_modifier=".cg",
).to(tl.float32)
x = x + res

w = tl.load(
weight_ptr + n_offs,
mask=mask,
other=0.0,
).to(tl.float32)
    out = _rmsnorm_op(x, w, N, eps).to(out_ptr.dtype.element_ty)

tl.store(
out_ptr + pid_m * out_stride_m + n_offs * out_stride_n,
out,
mask=(n_offs < N_OUT),
)
if HAS_RES:
tl.store(
res_out_ptr + pid_m * res_out_stride_m + n_offs * res_out_stride_n,
x,
mask=mask,
)
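
Likewise, a hypothetical wrapper sketch for the padded fused-add RMSNorm kernel (not in this diff); the function name and argument handling are assumptions. A single program covers an entire row, so BLOCK_SIZE_N has to span both N and N_OUT; output columns beyond N come out as zeros because both x and weight are masked to zero there.

import torch
import triton

from aiter.ops.triton._triton_kernels.fused_add_rmsnorm_pad import (
    _fused_add_rmsnorm_pad,
)


def fused_add_rmsnorm_pad(x, weight, eps: float, n_out: int, res=None):
    M, N = x.shape
    has_res = res is not None
    out = torch.empty((M, n_out), dtype=x.dtype, device=x.device)
    # With no residual, pass dummy tensors; the kernel never touches them (HAS_RES=False).
    res_in = res if has_res else x
    res_out = torch.empty_like(x) if has_res else out
    BLOCK_SIZE_N = triton.next_power_of_2(max(N, n_out))  # one tile per full row

    _fused_add_rmsnorm_pad[(M,)](
        x, res_in, out, res_out, weight, eps,
        M, N, n_out,
        x.stride(0), x.stride(1),
        res_in.stride(0), res_in.stride(1),
        out.stride(0), out.stride(1),
        res_out.stride(0), res_out.stride(1),
        HAS_RES=has_res,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
    )
    return (out, res_out) if has_res else out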