Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions aiter/ops/triton/_triton_kernels/fused_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,152 @@ def _fused_flatten_fp8_group_quant_kernel(
out_block_scales.to(out_scales_ptr.dtype.element_ty),
mask=block_scale_offs < tl.cdiv(N2, QUANT_BLOCK_SIZE),
)


@triton.jit
def _fused_reduce_act_mul_fp8_group_quant(
    x_ptr,
    y_ptr,
    y_scale_ptr,
    x2_ptr,
    y2_ptr,
    M,
    N1,
    N2,
    stride_x_spk,
    stride_x_m,
    stride_x_n,
    stride_y_m,
    stride_y_n,
    stride_y_scale_m,
    stride_y_scale_n,
    stride_x2_spk,
    stride_x2_m,
    stride_x2_n,
    stride_y2_m,
    stride_y2_n,
    # Meta-parameters
    ACTIVATION: tl.constexpr,
    BLOCK_SIZE_M2: tl.constexpr,
    BLOCK_SIZE_N1: tl.constexpr,
    BLOCK_SIZE_N2: tl.constexpr,
    QUANT_BLOCK_SIZE: tl.constexpr,
    DTYPE_MAX: tl.constexpr,
    DTYPE_MIN: tl.constexpr,
    X_HAS_SPLITK: tl.constexpr,
    X_NUM_KSPLIT: tl.constexpr,
    X_NUM_KSPLIT_POW2: tl.constexpr,
    X_MASK: tl.constexpr,
):
    """Fused (split-K reduce) + gated activation + group quantization kernel.

    The 1-D launch grid is shared by two kinds of programs:

    * Programs with ``pid < M`` each handle one row of ``x``. They load the
      columns ``[0, N1)`` and ``[N1, 2 * N1)`` of that row, sum over the
      split-K axis when ``X_HAS_SPLITK`` is set, compute
      ``ACTIVATION(x[:, :N1]) * x[:, N1:2 * N1]`` in float32, quantize the
      result with ``_fp8_quant_op`` and store the quantized values to ``y``
      and the per-group scales to ``y_scale``.
    * When ``X_HAS_SPLITK`` is set, programs with ``pid >= M`` reduce ``x2``
      over its split-K axis in ``BLOCK_SIZE_M2 x BLOCK_SIZE_N2`` tiles and
      store the sum to ``y2``.

    NOTE(review): without ``X_HAS_SPLITK`` there is no ``pid < M`` guard and
    ``x2`` / ``y2`` are never touched, so the launcher presumably uses a grid
    of exactly ``M`` programs in that case -- confirm against the caller.

    Args:
        x_ptr: Input; indexed as (split, row, col) when ``X_HAS_SPLITK`` is
            set, otherwise as (row, col). At least ``2 * N1`` columns are read.
        y_ptr: Quantized output, indexed as (row, col) with ``N1`` columns.
        y_scale_ptr: Scale output, indexed as (row, group).
        x2_ptr: Second split-K input, indexed as (split, row, col); only read
            when ``X_HAS_SPLITK`` is set.
        y2_ptr: Output of the ``x2`` reduction, indexed as (row, col).
        M: Number of rows.
        N1: Number of activation output columns per row.
        N2: Number of columns of ``x2`` / ``y2``.
        stride_*: Element strides of the corresponding tensors. All of them
            are declared strictly positive via ``tl.assume``.
        ACTIVATION: JIT-callable applied to the first ``N1`` columns.
        BLOCK_SIZE_M2: Row tile size for the ``x2`` reduction.
        BLOCK_SIZE_N1: Column tile size for ``x``; one program covers a whole
            row, so this presumably satisfies ``BLOCK_SIZE_N1 >= N1``.
        BLOCK_SIZE_N2: Column tile size for the ``x2`` reduction.
        QUANT_BLOCK_SIZE: Number of columns sharing one scale.
        DTYPE_MAX: Forwarded to ``_fp8_quant_op`` (presumably the largest
            representable value of the quantized dtype).
        DTYPE_MIN: Forwarded to ``_fp8_quant_op`` (presumably the smallest
            representable value of the quantized dtype).
        X_HAS_SPLITK: Whether ``x`` / ``x2`` carry a leading split-K axis.
        X_NUM_KSPLIT: Number of valid split-K slices.
        X_NUM_KSPLIT_POW2: Extent used for the split-K ``tl.arange``; slices
            at or beyond ``X_NUM_KSPLIT`` are masked to zero.
        X_MASK: Whether the ``x`` loads need a column mask ``col < N1``.
    """
    tl.assume(stride_x_spk > 0)
    tl.assume(stride_x_m > 0)
    tl.assume(stride_x_n > 0)
    tl.assume(stride_y_m > 0)
    tl.assume(stride_y_n > 0)
    tl.assume(stride_y_scale_m > 0)
    tl.assume(stride_y_scale_n > 0)
    tl.assume(stride_x2_spk > 0)
    tl.assume(stride_x2_m > 0)
    tl.assume(stride_x2_n > 0)
    tl.assume(stride_y2_m > 0)
    tl.assume(stride_y2_n > 0)

    m_pid = tl.program_id(axis=0)

    # ---- Role 2: split-K reduction of x2 (programs past the first M) ----
    if X_HAS_SPLITK and m_pid >= M:
        pid2 = m_pid - M
        num_pid_n2 = tl.cdiv(N2, BLOCK_SIZE_N2)
        pid_m2 = pid2 // num_pid_n2
        pid_n2 = pid2 % num_pid_n2
        # Out-of-range rows/cols are wrapped back into bounds instead of
        # being masked, so the load and store below need no m/n mask. A
        # wrapped lane recomputes and rewrites an element that is already
        # covered by another lane.
        offs_m2 = (pid_m2 * BLOCK_SIZE_M2 + tl.arange(0, BLOCK_SIZE_M2)) % M
        offs_n2 = (pid_n2 * BLOCK_SIZE_N2 + tl.arange(0, BLOCK_SIZE_N2)) % N2
        offs_spk = tl.arange(0, X_NUM_KSPLIT_POW2)
        x2_ptrs = (
            x2_ptr
            + offs_spk[:, None, None] * stride_x2_spk
            + offs_m2[None, :, None] * stride_x2_m
            + offs_n2[None, None, :] * stride_x2_n
        )
        if X_NUM_KSPLIT_POW2 == X_NUM_KSPLIT:
            x2 = tl.load(x2_ptrs)
        else:
            # Padding slices beyond X_NUM_KSPLIT contribute zero to the sum.
            x2 = tl.load(
                x2_ptrs, mask=offs_spk[:, None, None] < X_NUM_KSPLIT, other=0.0
            )
        x2 = tl.sum(x2, axis=0)

        x2 = x2.to(y2_ptr.dtype.element_ty)

        y2_out_ptrs = (
            y2_ptr + (offs_m2[:, None] * stride_y2_m) + (offs_n2[None, :] * stride_y2_n)
        )

        tl.store(y2_out_ptrs, x2)
        return

    # ---- Role 1: one row of x -> activation * gate -> group quant ----
    n_offs = tl.arange(0, BLOCK_SIZE_N1)
    NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N1 // QUANT_BLOCK_SIZE

    mask = None
    other = None
    if X_HAS_SPLITK:
        offs_spk = tl.arange(0, X_NUM_KSPLIT_POW2)
        x_ptrs = (
            x_ptr
            + offs_spk[:, None] * stride_x_spk
            + m_pid * stride_x_m
            + n_offs[None, :] * stride_x_n
        )
        if X_MASK:
            mask = (offs_spk[:, None] < X_NUM_KSPLIT) & (n_offs[None, :] < N1)
            other = 0.0
        else:
            mask = offs_spk[:, None] < X_NUM_KSPLIT
            other = 0.0
    else:
        x_ptrs = x_ptr + m_pid * stride_x_m + n_offs * stride_x_n
        if X_MASK:
            mask = n_offs < N1
            other = 0.0

    # The two operands live side by side in the same row: columns [0, N1)
    # feed the activation, columns [N1, 2 * N1) are the multiplier.
    x = tl.load(
        x_ptrs,
        mask=mask,
        other=other,
        cache_modifier=".cg",
    ).to(tl.float32)
    x_mul = tl.load(
        x_ptrs + N1 * stride_x_n,
        mask=mask,
        other=other,
        cache_modifier=".cg",
    ).to(tl.float32)

    if X_HAS_SPLITK:
        # Collapse the split-K axis; masked slices were loaded as zero.
        x = tl.sum(x, axis=0)
        x_mul = tl.sum(x_mul, axis=0)

    x = ACTIVATION(x) * x_mul

    # _fp8_quant_op is defined elsewhere in this file; the arguments
    # (1, BLOCK_SIZE_N1) are presumably the (rows, cols) of the tile.
    y, y_scale = _fp8_quant_op(
        x, 1, BLOCK_SIZE_N1, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN
    )
    y = tl.ravel(y)
    y_scale = tl.ravel(y_scale)

    # The load mask may be 2-D (split-K path), so build a fresh 1-D column
    # mask for the flattened store. It is the same expression regardless of
    # X_MASK.
    mask = n_offs < N1
    tl.store(
        y_ptr + m_pid * stride_y_m + n_offs * stride_y_n,
        y.to(y_ptr.dtype.element_ty),
        mask=mask,
    )
    g_offs = tl.arange(0, NUM_QUANT_BLOCKS)
    # Only the groups that overlap the first N1 columns carry a valid scale.
    num_bs_cols = tl.cdiv(N1, QUANT_BLOCK_SIZE)
    tl.store(
        y_scale_ptr + m_pid * stride_y_scale_m + g_offs * stride_y_scale_n,
        y_scale.to(y_scale_ptr.dtype.element_ty),
        mask=g_offs < num_bs_cols,
    )
Loading
Loading