diff --git a/aiter/ops/triton/_triton_kernels/moe/activations.py b/aiter/ops/triton/_triton_kernels/moe/activations.py new file mode 100644 index 0000000000..c1401db012 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/moe/activations.py @@ -0,0 +1,32 @@ +import triton +import triton.language as tl + + +@triton.jit +def clip(x, limit, clip_lower: tl.constexpr): + res = tl.minimum(x, limit) + if clip_lower: + res = tl.maximum(-limit, res) + return res + + +@triton.jit +def _swiglu(input, alpha, limit, ADD_RESIDUAL: tl.constexpr): + """ + SwiGLU activation + + s = silu(gelu), then returns s * (linear + 1) if ADD_RESIDUAL else s * linear. + if alpha=1.0, then this is the same as the SiLU activation. + """ + gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2))) + gelu = gelu.to(tl.float32) + if limit is not None: + gelu = clip(gelu, limit, clip_lower=False) + linear = linear.to(tl.float32) + if limit is not None: + linear = clip(linear, limit, clip_lower=True) + s = gelu / (1 + tl.exp2(-1.44269504089 * alpha * gelu)) + if ADD_RESIDUAL: + return tl.fma(s, linear, s) # s * (linear + 1) + else: + return s * linear diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a4w4.py b/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a4w4.py index d13329493a..baba601527 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a4w4.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a4w4.py @@ -7,6 +7,7 @@ from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid from aiter.ops.triton._triton_kernels.moe.quant_moe import _compute_static_fp8_quant from aiter.ops.triton._triton_kernels.quant.quant import _mxfp4_quant_op +from aiter.ops.triton._triton_kernels.moe.activations import _swiglu def matmul_launch_metadata(grid, kernel, args): @@ -105,96 +106,6 @@ def unswizzle_mx_scale_cdna4( return x -@triton.jit -def clip(x, limit, clip_lower: tl.constexpr): - res = tl.minimum(x, limit) - if clip_lower: - res = tl.maximum(-limit, res) - return res - - -@triton.jit -def _swiglu(input, alpha, limit): - gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2))) - gelu = gelu.to(tl.float32) - if limit is not None: - gelu = clip(gelu, limit, clip_lower=False) - linear = linear.to(tl.float32) - if limit is not None: - linear = clip(linear, limit, clip_lower=True) - s = gelu / (1 + tl.exp2(-1.44269504089 * alpha * gelu)) - return tl.fma(s, linear, s) # (s * (linear + 1)) - - -@triton.jit -def _reduce_grouped( - X, - stride_xb: tl.uint64, - stride_xm: tl.uint64, - stride_xn, # - Out, - stride_om: tl.uint64, - stride_on, # output tensor - InIndx, - B, - N, # - # fused activation function - APPLY_SWIGLU: tl.constexpr, - alpha, - limit, - ACTIVATION_REDUCTION_N: tl.constexpr, - K: tl.constexpr, - BLOCK_N: tl.constexpr, - EVEN_N: tl.constexpr, -): - pid_t = tl.program_id(1) - pid_n = tl.program_id(0) - - BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N - start = pid_t * K - # load indices into a tuple - if InIndx is None: - indxs = (pid_t,) - else: - indxs = () - for i in tl.static_range(0, K): - indxs = indxs + (tl.load(InIndx + start + i),) - XPtrs = X + (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) * stride_xn - OutPtrs = Out + (pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT)) * stride_on - - acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32) - x_n_mask = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) < N - # accumulate contributions for this tile - for i in tl.static_range(0, K): - curr = tl.zeros([BLOCK_N], dtype=tl.float32) - # 
iterate over split_k partial values - for b in tl.range(0, B): - x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb - if EVEN_N: - vals = tl.load(x_row_ptr) - else: - vals = tl.load(x_row_ptr, mask=x_n_mask, other=0.0) - vals = vals.to(tl.float32) - curr += vals - - # apply nonlinearity to split-k output - if APPLY_SWIGLU: - curr = _swiglu(curr[None, :], alpha, limit) - curr = tl.reshape(curr, [curr.shape[-1]]) - # update final accumulator - acc += curr - # Compute per-32-col MXFP scales for this tile if requested - Nrem = N // ACTIVATION_REDUCTION_N - - # write-back for this tile - out_ptr = OutPtrs + pid_t * stride_om - if EVEN_N: - tl.store(out_ptr, acc) - else: - out_n_mask = pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT) < Nrem - tl.store(out_ptr, acc, mask=out_n_mask) - - @triton.jit def _mxfp4_quant_kernel( x_ptr, @@ -298,6 +209,7 @@ def _moe_gemm_a4w4( alpha, limit, ACTIVATION_REDUCTION_N: tl.constexpr, + ADD_RESIDUAL: tl.constexpr, # MoE config N_EXPTS_ACT: tl.constexpr, # optimization config @@ -556,7 +468,7 @@ def _moe_gemm_a4w4( bias = tl.full([BLOCK_N], 0, dtype=tl.float32) acc = acc + bias[None, :] if APPLY_SWIGLU and SPLIT_K == 1: - out = _swiglu(acc, alpha, limit) + out = _swiglu(acc, alpha, limit, ADD_RESIDUAL=ADD_RESIDUAL) tl.static_assert( out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})", diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w4.py b/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w4.py index 75d0b4f5ee..b71090c143 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w4.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w4.py @@ -6,6 +6,7 @@ import triton.language as tl from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid from aiter.ops.triton._triton_kernels.moe.quant_moe import _compute_static_fp8_quant +from aiter.ops.triton._triton_kernels.moe.activations import _swiglu def matmul_launch_metadata(grid, kernel, args): @@ -104,96 +105,6 @@ def unswizzle_mx_scale_cdna4( return x -@triton.jit -def clip(x, limit, clip_lower: tl.constexpr): - res = tl.minimum(x, limit) - if clip_lower: - res = tl.maximum(-limit, res) - return res - - -@triton.jit -def _swiglu(input, alpha, limit): - gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2))) - gelu = gelu.to(tl.float32) - if limit is not None: - gelu = clip(gelu, limit, clip_lower=False) - linear = linear.to(tl.float32) - if limit is not None: - linear = clip(linear, limit, clip_lower=True) - s = gelu / (1 + tl.exp2(-1.44269504089 * alpha * gelu)) - return tl.fma(s, linear, s) # (s * (linear + 1)) - - -@triton.jit -def _reduce_grouped( - X, - stride_xb: tl.uint64, - stride_xm: tl.uint64, - stride_xn, # - Out, - stride_om: tl.uint64, - stride_on, # output tensor - InIndx, - B, - N, # - # fused activation function - APPLY_SWIGLU: tl.constexpr, - alpha, - limit, - ACTIVATION_REDUCTION_N: tl.constexpr, - K: tl.constexpr, - BLOCK_N: tl.constexpr, - EVEN_N: tl.constexpr, -): - pid_t = tl.program_id(1) - pid_n = tl.program_id(0) - - BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N - start = pid_t * K - # load indices into a tuple - if InIndx is None: - indxs = (pid_t,) - else: - indxs = () - for i in tl.static_range(0, K): - indxs = indxs + (tl.load(InIndx + start + i),) - XPtrs = X + (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) * stride_xn - OutPtrs = Out + (pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT)) * stride_on - - acc = 
tl.zeros([BLOCK_N_OUT], dtype=tl.float32) - x_n_mask = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) < N - # accumulate contributions for this tile - for i in tl.static_range(0, K): - curr = tl.zeros([BLOCK_N], dtype=tl.float32) - # iterate over split_k partial values - for b in tl.range(0, B): - x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb - if EVEN_N: - vals = tl.load(x_row_ptr) - else: - vals = tl.load(x_row_ptr, mask=x_n_mask, other=0.0) - vals = vals.to(tl.float32) - curr += vals - - # apply nonlinearity to split-k output - if APPLY_SWIGLU: - curr = _swiglu(curr[None, :], alpha, limit) - curr = tl.reshape(curr, [curr.shape[-1]]) - # update final accumulator - acc += curr - # Compute per-32-col MXFP scales for this tile if requested - Nrem = N // ACTIVATION_REDUCTION_N - - # write-back for this tile - out_ptr = OutPtrs + pid_t * stride_om - if EVEN_N: - tl.store(out_ptr, acc) - else: - out_n_mask = pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT) < Nrem - tl.store(out_ptr, acc, mask=out_n_mask) - - @triton.jit(launch_metadata=matmul_launch_metadata) def _moe_gemm_a8w4( Y, @@ -235,6 +146,7 @@ def _moe_gemm_a8w4( alpha, limit, ACTIVATION_REDUCTION_N: tl.constexpr, + ADD_RESIDUAL: tl.constexpr, # MoE config N_EXPTS_ACT: tl.constexpr, # optimization config @@ -481,7 +393,7 @@ def _moe_gemm_a8w4( bias = tl.full([BLOCK_N], 0, dtype=tl.float32) acc = acc + bias[None, :] if APPLY_SWIGLU and SPLIT_K == 1: - out = _swiglu(acc, alpha, limit) + out = _swiglu(acc, alpha, limit, ADD_RESIDUAL=ADD_RESIDUAL) tl.static_assert( out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})", diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w8.py b/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w8.py index 24e8e35040..937352e4ad 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w8.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w8.py @@ -6,6 +6,7 @@ import triton.language as tl from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid from aiter.ops.triton._triton_kernels.moe.quant_moe import _compute_static_fp8_quant +from aiter.ops.triton._triton_kernels.moe.activations import _swiglu def matmul_launch_metadata(grid, kernel, args): @@ -104,96 +105,6 @@ def unswizzle_mx_scale_cdna4( return x -@triton.jit -def clip(x, limit, clip_lower: tl.constexpr): - res = tl.minimum(x, limit) - if clip_lower: - res = tl.maximum(-limit, res) - return res - - -@triton.jit -def _swiglu(input, alpha, limit): - gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2))) - gelu = gelu.to(tl.float32) - if limit is not None: - gelu = clip(gelu, limit, clip_lower=False) - linear = linear.to(tl.float32) - if limit is not None: - linear = clip(linear, limit, clip_lower=True) - s = gelu / (1 + tl.exp2(-1.44269504089 * alpha * gelu)) - return tl.fma(s, linear, s) # (s * (linear + 1)) - - -@triton.jit -def _reduce_grouped( - X, - stride_xb: tl.uint64, - stride_xm: tl.uint64, - stride_xn, # - Out, - stride_om: tl.uint64, - stride_on, # output tensor - InIndx, - B, - N, # - # fused activation function - APPLY_SWIGLU: tl.constexpr, - alpha, - limit, - ACTIVATION_REDUCTION_N: tl.constexpr, - K: tl.constexpr, - BLOCK_N: tl.constexpr, - EVEN_N: tl.constexpr, -): - pid_t = tl.program_id(1) - pid_n = tl.program_id(0) - - BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N - start = pid_t * K - # load indices into a tuple - if InIndx is None: - indxs = (pid_t,) - else: - indxs = () - 
for i in tl.static_range(0, K): - indxs = indxs + (tl.load(InIndx + start + i),) - XPtrs = X + (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) * stride_xn - OutPtrs = Out + (pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT)) * stride_on - - acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32) - x_n_mask = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) < N - # accumulate contributions for this tile - for i in tl.static_range(0, K): - curr = tl.zeros([BLOCK_N], dtype=tl.float32) - # iterate over split_k partial values - for b in tl.range(0, B): - x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb - if EVEN_N: - vals = tl.load(x_row_ptr) - else: - vals = tl.load(x_row_ptr, mask=x_n_mask, other=0.0) - vals = vals.to(tl.float32) - curr += vals - - # apply nonlinearity to split-k output - if APPLY_SWIGLU: - curr = _swiglu(curr[None, :], alpha, limit) - curr = tl.reshape(curr, [curr.shape[-1]]) - # update final accumulator - acc += curr - # Compute per-32-col MXFP scales for this tile if requested - Nrem = N // ACTIVATION_REDUCTION_N - - # write-back for this tile - out_ptr = OutPtrs + pid_t * stride_om - if EVEN_N: - tl.store(out_ptr, acc) - else: - out_n_mask = pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT) < Nrem - tl.store(out_ptr, acc, mask=out_n_mask) - - @triton.jit(launch_metadata=matmul_launch_metadata) def _moe_gemm_a8w8( Y, @@ -236,6 +147,7 @@ def _moe_gemm_a8w8( alpha, limit, ACTIVATION_REDUCTION_N: tl.constexpr, + ADD_RESIDUAL: tl.constexpr, # MoE config N_EXPTS_ACT: tl.constexpr, # optimization config @@ -491,7 +403,7 @@ def _moe_gemm_a8w8( bias = tl.full([BLOCK_N], 0, dtype=tl.float32) acc = acc + bias[None, :] if APPLY_SWIGLU and SPLIT_K == 1: - out = _swiglu(acc, alpha, limit) + out = _swiglu(acc, alpha, limit, ADD_RESIDUAL=ADD_RESIDUAL) tl.static_assert( out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})", diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w8_blockscale.py b/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w8_blockscale.py index 5bf20a9817..46c8358aed 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w8_blockscale.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_a8w8_blockscale.py @@ -6,6 +6,7 @@ import triton.language as tl from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid from aiter.ops.triton._triton_kernels.moe.quant_moe import _compute_static_fp8_quant +from aiter.ops.triton._triton_kernels.moe.activations import _swiglu def matmul_launch_metadata(grid, kernel, args): @@ -91,96 +92,6 @@ def xcd_swizzle(pid, domain_size, XCD_SWIZZLE: tl.constexpr): return new_pid -@triton.jit -def clip(x, limit, clip_lower: tl.constexpr): - res = tl.minimum(x, limit) - if clip_lower: - res = tl.maximum(-limit, res) - return res - - -@triton.jit -def _swiglu(input, alpha, limit): - gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2))) - gelu = gelu.to(tl.float32) - if limit is not None: - gelu = clip(gelu, limit, clip_lower=False) - linear = linear.to(tl.float32) - if limit is not None: - linear = clip(linear, limit, clip_lower=True) - s = gelu / (1 + tl.exp2(-1.44269504089 * alpha * gelu)) - return tl.fma(s, linear, s) # (s * (linear + 1)) - - -@triton.jit -def _reduce_grouped( - X, - stride_xb: tl.uint64, - stride_xm: tl.uint64, - stride_xn, # - Out, - stride_om: tl.uint64, - stride_on, # output tensor - InIndx, - B, - N, # - # fused activation function - APPLY_SWIGLU: tl.constexpr, - alpha, - limit, - ACTIVATION_REDUCTION_N: 
tl.constexpr, - K: tl.constexpr, - BLOCK_N: tl.constexpr, - EVEN_N: tl.constexpr, -): - pid_t = tl.program_id(1) - pid_n = tl.program_id(0) - - BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N - start = pid_t * K - # load indices into a tuple - if InIndx is None: - indxs = (pid_t,) - else: - indxs = () - for i in tl.static_range(0, K): - indxs = indxs + (tl.load(InIndx + start + i),) - XPtrs = X + (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) * stride_xn - OutPtrs = Out + (pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT)) * stride_on - - acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32) - x_n_mask = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) < N - # accumulate contributions for this tile - for i in tl.static_range(0, K): - curr = tl.zeros([BLOCK_N], dtype=tl.float32) - # iterate over split_k partial values - for b in tl.range(0, B): - x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb - if EVEN_N: - vals = tl.load(x_row_ptr) - else: - vals = tl.load(x_row_ptr, mask=x_n_mask, other=0.0) - vals = vals.to(tl.float32) - curr += vals - - # apply nonlinearity to split-k output - if APPLY_SWIGLU: - curr = _swiglu(curr[None, :], alpha, limit) - curr = tl.reshape(curr, [curr.shape[-1]]) - # update final accumulator - acc += curr - # Compute per-32-col MXFP scales for this tile if requested - Nrem = N // ACTIVATION_REDUCTION_N - - # write-back for this tile - out_ptr = OutPtrs + pid_t * stride_om - if EVEN_N: - tl.store(out_ptr, acc) - else: - out_n_mask = pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT) < Nrem - tl.store(out_ptr, acc, mask=out_n_mask) - - @triton.jit(launch_metadata=matmul_launch_metadata) def _moe_gemm_a8w8_blockscale( Y, @@ -223,6 +134,7 @@ def _moe_gemm_a8w8_blockscale( alpha, limit, ACTIVATION_REDUCTION_N: tl.constexpr, + ADD_RESIDUAL: tl.constexpr, # MoE config N_EXPTS_ACT: tl.constexpr, # optimization config @@ -434,7 +346,7 @@ def _moe_gemm_a8w8_blockscale( bias = tl.full([BLOCK_N], 0, dtype=tl.float32) acc = acc + bias[None, :] if APPLY_SWIGLU and SPLIT_K == 1: - out = _swiglu(acc, alpha, limit) + out = _swiglu(acc, alpha, limit, ADD_RESIDUAL=ADD_RESIDUAL) tl.static_assert( out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})", diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_int8_smoothquant.py b/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_int8_smoothquant.py index cb96ba783a..d3d91ca99d 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_int8_smoothquant.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_op_gemm_int8_smoothquant.py @@ -5,6 +5,7 @@ import triton import triton.language as tl from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid +from aiter.ops.triton._triton_kernels.moe.activations import _swiglu def matmul_launch_metadata(grid, kernel, args): @@ -66,30 +67,6 @@ def repr(s, x): return ret -@triton.jit -def clip(x, limit, clip_lower: tl.constexpr): - res = tl.minimum(x, limit) - if clip_lower: - res = tl.maximum(-limit, res) - return res - - -@triton.jit -def _swiglu(input, alpha, limit, ADD_RESIDUAL: tl.constexpr): - gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2))) - gelu = gelu.to(tl.float32) - if limit is not None: - gelu = clip(gelu, limit, clip_lower=False) - linear = linear.to(tl.float32) - if limit is not None: - linear = clip(linear, limit, clip_lower=True) - s = gelu / (1 + tl.exp2(-1.44269504089 * alpha * gelu)) - if ADD_RESIDUAL: - return tl.fma(s, linear, s) # s * (linear + 1) - else: - 
return s * linear - - @triton.jit def unshuffle_weights(w, BLOCK_N, BLOCK_K): w = w.trans() @@ -100,75 +77,6 @@ def unshuffle_weights(w, BLOCK_N, BLOCK_K): return w -@triton.jit -def _reduce_grouped( - X, - stride_xb: tl.uint64, - stride_xm: tl.uint64, - stride_xn, - Out, - stride_om: tl.uint64, - stride_on, - InIndx, - B, - N, - alpha, - limit, - ACTIVATION_REDUCTION_N: tl.constexpr, - K: tl.constexpr, - BLOCK_N: tl.constexpr, - EVEN_N: tl.constexpr, - ADD_RESIDUAL: tl.constexpr, - APPLY_ACTIVATION: tl.constexpr, -): - pid_t = tl.program_id(1) - pid_n = tl.program_id(0) - - BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N - start = pid_t * K - # load indices into a tuple - if InIndx is None: - indxs = (pid_t,) - else: - indxs = () - for i in tl.static_range(0, K): - indxs = indxs + (tl.load(InIndx + start + i),) - XPtrs = X + (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) * stride_xn - OutPtrs = Out + (pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT)) * stride_on - - acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32) - x_n_mask = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) < N - # accumulate contributions for this tile - for i in tl.static_range(0, K): - curr = tl.zeros([BLOCK_N], dtype=tl.float32) - # iterate over split_k partial values - for b in tl.range(0, B): - x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb - if EVEN_N: - vals = tl.load(x_row_ptr) - else: - vals = tl.load(x_row_ptr, mask=x_n_mask, other=0.0) - vals = vals.to(tl.float32) - curr += vals - - # apply nonlinearity to split-k output - if APPLY_ACTIVATION: - curr = _swiglu(curr[None, :], alpha, limit, ADD_RESIDUAL=ADD_RESIDUAL) - curr = tl.reshape(curr, [curr.shape[-1]]) - # update final accumulator - acc += curr - # Compute per-32-col MXFP scales for this tile if requested - Nrem = N // ACTIVATION_REDUCTION_N - - # write-back for this tile - out_ptr = OutPtrs + pid_t * stride_om - if EVEN_N: - tl.store(out_ptr, acc) - else: - out_n_mask = pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT) < Nrem - tl.store(out_ptr, acc, mask=out_n_mask) - - @triton.jit(launch_metadata=matmul_launch_metadata) def _moe_gemm_int8_smoothquant( Y, diff --git a/aiter/ops/triton/_triton_kernels/moe/reduce.py b/aiter/ops/triton/_triton_kernels/moe/reduce.py new file mode 100644 index 0000000000..665ce7548d --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/moe/reduce.py @@ -0,0 +1,75 @@ +import triton +import triton.language as tl +from aiter.ops.triton._triton_kernels.moe.activations import _swiglu + + +@triton.jit +def _reduce_grouped( + X, + stride_xb: tl.uint64, + stride_xm: tl.uint64, + stride_xn, # + Out, + stride_om: tl.uint64, + stride_on, # output tensor + InIndx, + B, + N, + num_blocks, + # fused activation function + APPLY_SWIGLU: tl.constexpr, + alpha, + limit, + ACTIVATION_REDUCTION_N: tl.constexpr, + K: tl.constexpr, + BLOCK_N: tl.constexpr, + EVEN_N: tl.constexpr, + ADD_RESIDUAL: tl.constexpr, +): + pid = tl.program_id(0) + pid_t = pid // num_blocks + pid_n = pid % num_blocks + + BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N + start = pid_t * K + # load indices into a tuple + if InIndx is None: + indxs = (pid_t,) + else: + indxs = () + for i in tl.static_range(0, K): + indxs = indxs + (tl.load(InIndx + start + i),) + XPtrs = X + (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) * stride_xn + OutPtrs = Out + (pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT)) * stride_on + + acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32) + x_n_mask = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) < N + # accumulate contributions for this 
tile + for i in tl.static_range(0, K): + curr = tl.zeros([BLOCK_N], dtype=tl.float32) + # iterate over split_k partial values + for b in tl.range(0, B): + x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb + if EVEN_N: + vals = tl.load(x_row_ptr) + else: + vals = tl.load(x_row_ptr, mask=x_n_mask, other=0.0) + vals = vals.to(tl.float32) + curr += vals + + # apply nonlinearity to split-k output + if APPLY_SWIGLU: + curr = _swiglu(curr[None, :], alpha, limit, ADD_RESIDUAL=ADD_RESIDUAL) + curr = tl.reshape(curr, [curr.shape[-1]]) + # update final accumulator + acc += curr + # Compute per-32-col MXFP scales for this tile if requested + Nrem = N // ACTIVATION_REDUCTION_N + + # write-back for this tile + out_ptr = OutPtrs + pid_t * stride_om + if EVEN_N: + tl.store(out_ptr, acc) + else: + out_n_mask = pid_n * BLOCK_N_OUT + tl.arange(0, BLOCK_N_OUT) < Nrem + tl.store(out_ptr, acc, mask=out_n_mask) diff --git a/aiter/ops/triton/moe/moe_op_gemm_a4w4.py b/aiter/ops/triton/moe/moe_op_gemm_a4w4.py index f688c44db1..0b2f70c704 100644 --- a/aiter/ops/triton/moe/moe_op_gemm_a4w4.py +++ b/aiter/ops/triton/moe/moe_op_gemm_a4w4.py @@ -8,8 +8,8 @@ from aiter.ops.triton._triton_kernels.moe.moe_op_gemm_a4w4 import ( _mxfp4_quant_kernel, _moe_gemm_a4w4, - _reduce_grouped, ) +from aiter.ops.triton.moe.reduce import reduce_grouped # ----------------------------------------------------------------------------- # Matrix Multiplication + Outer Gather/Scatter @@ -122,80 +122,6 @@ def swizzle_scales(data): return data.transpose(-1, -2) -def reduce_grouped( - x: torch.Tensor, - indx: torch.Tensor, - out: torch.Tensor, - apply_swiglu=False, - alpha=1.0, - limit=1.0, - reduction_n=1, - out_dtype: bool = None, -): - """ - In-place grouped row reduction. - - Arguments - - x: Tensor[AnyFloat] of shape [(num_groups * K), N] - - indx: Tensor[Int] of shape [num_groups, K] - - Description - For each group g in [0, num_groups), this routine sums the K rows of `x` - specified by `indx[g, :]` and overwrites the row corresponding to the first - valid (non-negative) index with the per-group sum. Accumulation is performed - in float32 for numerical stability, and the result is written back in the - dtype of `x`. - - Behavior and edge cases - - Invalid (-1) entries are skipped during accumulation and do not generate - memory traffic. If a group has no valid entries, nothing is written for - that group. - - Reduction is performed tile-by-tile along the N dimension within a single - kernel launch (persistent along N) to minimize launch overhead. - - Performance notes - - Memory traffic per group is approximately (valid_rows_read + 1) * N * sizeof(x), - plus index reads. With no invalid entries, this becomes (K + 1) reads/writes - of length N per group. - - Returns - - The input tensor `x` (modified in place). 
- """ - if indx is None and x.shape[0] == 1: - return x.squeeze(0) - if indx is not None: - num_groups = indx.shape[0] - else: - num_groups = x.shape[-2] - K = 1 if indx is None else indx.shape[1] - out_dtype = x.dtype if out_dtype is None else out_dtype - assert x.shape[-1] % reduction_n == 0 - BLOCK_N = 512 - num_blocks = triton.cdiv(x.shape[-1], BLOCK_N) - - _reduce_grouped[(num_blocks, num_groups)]( - x, - x.stride(0), - x.stride(1), - x.stride(2), # - out, - out.stride(0), - out.stride(1), # - indx, # - x.shape[0], - x.shape[-1], # - apply_swiglu, - alpha, - limit, - reduction_n, - BLOCK_N=BLOCK_N, - EVEN_N=(x.shape[-1] % BLOCK_N == 0), - K=K, # - num_warps=2, # - ) - return out - - # ----------------------------------------------------------------------------- # Triton Implementation # ----------------------------------------------------------------------------- @@ -273,6 +199,7 @@ def moe_gemm_a4w4( apply_swiglu=False, alpha=1.0, limit=1.0, + add_residual=True, unpadded_N=None, unpadded_K=None, ): @@ -378,6 +305,7 @@ def moe_gemm_a4w4( alpha, limit, reduction_n_matmul, + add_residual, routing_data.n_expts_act, config["block_m"], config["block_n"], @@ -411,6 +339,7 @@ def moe_gemm_a4w4( limit, reduction_n_reduction, out_dtype=out_dtype, + add_residual=add_residual, ) return y_final @@ -420,7 +349,7 @@ def moe_gemm_a4w4( # ----------------------------------------------------------------------------- -def swiglu_torch(a, alpha, limit): +def swiglu_torch(a, alpha, limit, add_residual=True): a_gelu = a[..., ::2] if limit is not None: a_gelu = a_gelu.clamp(max=limit) @@ -429,7 +358,10 @@ def swiglu_torch(a, alpha, limit): a_linear = a_linear.clamp(min=-limit, max=limit) out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu) - out = out_gelu * (a_linear + 1) + if add_residual: + out = out_gelu * (a_linear + 1) + else: + out = out_gelu * a_linear return out @@ -444,6 +376,7 @@ def moe_gemm_torch( apply_swiglu=False, alpha=1.0, limit=1.0, + add_residual=True, ): assert x.dtype.itemsize > 1 assert w.dtype.itemsize > 1 @@ -473,7 +406,7 @@ def moe_gemm_torch( if bias is not None: out += bias[i, :] if apply_swiglu: - out = swiglu_torch(out, alpha, limit) + out = swiglu_torch(out, alpha, limit, add_residual) if gammas is not None: out *= gammas[lo:hi, None] y[lo:hi, :] = out diff --git a/aiter/ops/triton/moe/moe_op_gemm_a8w4.py b/aiter/ops/triton/moe/moe_op_gemm_a8w4.py index 46102d47d7..e510ff6335 100644 --- a/aiter/ops/triton/moe/moe_op_gemm_a8w4.py +++ b/aiter/ops/triton/moe/moe_op_gemm_a8w4.py @@ -7,8 +7,8 @@ from aiter.ops.triton.moe.moe_routing.routing import RoutingData from aiter.ops.triton._triton_kernels.moe.moe_op_gemm_a8w4 import ( _moe_gemm_a8w4, - _reduce_grouped, ) +from aiter.ops.triton.moe.reduce import reduce_grouped # ----------------------------------------------------------------------------- # Matrix Multiplication + Outer Gather/Scatter @@ -131,80 +131,6 @@ def swizzle_scales(data): return data.transpose(-1, -2) -def reduce_grouped( - x: torch.Tensor, - indx: torch.Tensor, - out: torch.Tensor, - apply_swiglu=False, - alpha=1.0, - limit=1.0, - reduction_n=1, - out_dtype: bool = None, -): - """ - In-place grouped row reduction. - - Arguments - - x: Tensor[AnyFloat] of shape [(num_groups * K), N] - - indx: Tensor[Int] of shape [num_groups, K] - - Description - For each group g in [0, num_groups), this routine sums the K rows of `x` - specified by `indx[g, :]` and overwrites the row corresponding to the first - valid (non-negative) index with the per-group sum. 
Accumulation is performed - in float32 for numerical stability, and the result is written back in the - dtype of `x`. - - Behavior and edge cases - - Invalid (-1) entries are skipped during accumulation and do not generate - memory traffic. If a group has no valid entries, nothing is written for - that group. - - Reduction is performed tile-by-tile along the N dimension within a single - kernel launch (persistent along N) to minimize launch overhead. - - Performance notes - - Memory traffic per group is approximately (valid_rows_read + 1) * N * sizeof(x), - plus index reads. With no invalid entries, this becomes (K + 1) reads/writes - of length N per group. - - Returns - - The input tensor `x` (modified in place). - """ - if indx is None and x.shape[0] == 1: - return x.squeeze(0) - if indx is not None: - num_groups = indx.shape[0] - else: - num_groups = x.shape[-2] - K = 1 if indx is None else indx.shape[1] - out_dtype = x.dtype if out_dtype is None else out_dtype - assert x.shape[-1] % reduction_n == 0 - BLOCK_N = 512 - num_blocks = triton.cdiv(x.shape[-1], BLOCK_N) - - _reduce_grouped[(num_blocks, num_groups)]( - x, - x.stride(0), - x.stride(1), - x.stride(2), # - out, - out.stride(0), - out.stride(1), # - indx, # - x.shape[0], - x.shape[-1], # - apply_swiglu, - alpha, - limit, - reduction_n, - BLOCK_N=BLOCK_N, - EVEN_N=(x.shape[-1] % BLOCK_N == 0), - K=K, # - num_warps=2, # - ) - return out - - # ----------------------------------------------------------------------------- # Triton Implementation # ----------------------------------------------------------------------------- @@ -227,6 +153,7 @@ def moe_gemm_a8w4( apply_swiglu=False, alpha=1.0, limit=1.0, + add_residual=True, unpadded_N=None, unpadded_K=None, ): @@ -332,6 +259,7 @@ def moe_gemm_a8w4( alpha, limit, reduction_n_matmul, + add_residual, routing_data.n_expts_act, config["block_m"], config["block_n"], @@ -365,6 +293,7 @@ def moe_gemm_a8w4( limit, reduction_n_reduction, out_dtype=out_dtype, + add_residual=add_residual, ) return y_final @@ -374,7 +303,7 @@ def moe_gemm_a8w4( # ----------------------------------------------------------------------------- -def swiglu_torch(a, alpha, limit): +def swiglu_torch(a, alpha, limit, add_residual=True): a_gelu = a[..., ::2] if limit is not None: a_gelu = a_gelu.clamp(max=limit) @@ -383,7 +312,10 @@ def swiglu_torch(a, alpha, limit): a_linear = a_linear.clamp(min=-limit, max=limit) out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu) - out = out_gelu * (a_linear + 1) + if add_residual: + out = out_gelu * (a_linear + 1) + else: + out = out_gelu * a_linear return out @@ -398,6 +330,7 @@ def moe_gemm_torch( apply_swiglu=False, alpha=1.0, limit=1.0, + add_residual=True, ): assert x.dtype.itemsize > 1 assert w.dtype.itemsize > 1 @@ -427,7 +360,7 @@ def moe_gemm_torch( if bias is not None: out += bias[i, :] if apply_swiglu: - out = swiglu_torch(out, alpha, limit) + out = swiglu_torch(out, alpha, limit, add_residual) if gammas is not None: out *= gammas[lo:hi, None] y[lo:hi, :] = out diff --git a/aiter/ops/triton/moe/moe_op_gemm_a8w8.py b/aiter/ops/triton/moe/moe_op_gemm_a8w8.py index 9d42c8bf9f..48709371e6 100644 --- a/aiter/ops/triton/moe/moe_op_gemm_a8w8.py +++ b/aiter/ops/triton/moe/moe_op_gemm_a8w8.py @@ -7,8 +7,8 @@ from aiter.ops.triton.moe.moe_routing.routing import RoutingData from aiter.ops.triton._triton_kernels.moe.moe_op_gemm_a8w8 import ( _moe_gemm_a8w8, - _reduce_grouped, ) +from aiter.ops.triton.moe.reduce import reduce_grouped # 
----------------------------------------------------------------------------- # Matrix Multiplication + Outer Gather/Scatter @@ -121,80 +121,6 @@ def swizzle_scales(data): return data.transpose(-1, -2) -def reduce_grouped( - x: torch.Tensor, - indx: torch.Tensor, - out: torch.Tensor, - apply_swiglu=False, - alpha=1.0, - limit=1.0, - reduction_n=1, - out_dtype: bool = None, -): - """ - In-place grouped row reduction. - - Arguments - - x: Tensor[AnyFloat] of shape [(num_groups * K), N] - - indx: Tensor[Int] of shape [num_groups, K] - - Description - For each group g in [0, num_groups), this routine sums the K rows of `x` - specified by `indx[g, :]` and overwrites the row corresponding to the first - valid (non-negative) index with the per-group sum. Accumulation is performed - in float32 for numerical stability, and the result is written back in the - dtype of `x`. - - Behavior and edge cases - - Invalid (-1) entries are skipped during accumulation and do not generate - memory traffic. If a group has no valid entries, nothing is written for - that group. - - Reduction is performed tile-by-tile along the N dimension within a single - kernel launch (persistent along N) to minimize launch overhead. - - Performance notes - - Memory traffic per group is approximately (valid_rows_read + 1) * N * sizeof(x), - plus index reads. With no invalid entries, this becomes (K + 1) reads/writes - of length N per group. - - Returns - - The input tensor `x` (modified in place). - """ - if indx is None and x.shape[0] == 1: - return x.squeeze(0) - if indx is not None: - num_groups = indx.shape[0] - else: - num_groups = x.shape[-2] - K = 1 if indx is None else indx.shape[1] - out_dtype = x.dtype if out_dtype is None else out_dtype - assert x.shape[-1] % reduction_n == 0 - BLOCK_N = 512 - num_blocks = triton.cdiv(x.shape[-1], BLOCK_N) - - _reduce_grouped[(num_blocks, num_groups)]( - x, - x.stride(0), - x.stride(1), - x.stride(2), # - out, - out.stride(0), - out.stride(1), # - indx, # - x.shape[0], - x.shape[-1], # - apply_swiglu, - alpha, - limit, - reduction_n, - BLOCK_N=BLOCK_N, - EVEN_N=(x.shape[-1] % BLOCK_N == 0), - K=K, # - num_warps=2, # - ) - return out - - # ----------------------------------------------------------------------------- # Triton Implementation # ----------------------------------------------------------------------------- @@ -218,6 +144,7 @@ def moe_gemm_a8w8( apply_swiglu=False, alpha=1.0, limit=1.0, + add_residual=True, unpadded_N=None, unpadded_K=None, ): @@ -335,6 +262,7 @@ def moe_gemm_a8w8( alpha, limit, reduction_n_matmul, + add_residual, routing_data.n_expts_act, config["block_m"], config["block_n"], @@ -368,6 +296,7 @@ def moe_gemm_a8w8( limit, reduction_n_reduction, out_dtype=out_dtype, + add_residual=add_residual, ) return y_final @@ -377,7 +306,7 @@ def moe_gemm_a8w8( # ----------------------------------------------------------------------------- -def swiglu_torch(a, alpha, limit): +def swiglu_torch(a, alpha, limit, add_residual=True): a_gelu = a[..., ::2] if limit is not None: a_gelu = a_gelu.clamp(max=limit) @@ -386,7 +315,10 @@ def swiglu_torch(a, alpha, limit): a_linear = a_linear.clamp(min=-limit, max=limit) out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu) - out = out_gelu * (a_linear + 1) + if add_residual: + out = out_gelu * (a_linear + 1) + else: + out = out_gelu * a_linear return out @@ -401,6 +333,7 @@ def moe_gemm_torch( apply_swiglu=False, alpha=1.0, limit=1.0, + add_residual=True, ): assert x.dtype.itemsize > 1 assert w.dtype.itemsize > 1 @@ -430,7 +363,7 @@ def 
moe_gemm_torch( if bias is not None: out += bias[i, :] if apply_swiglu: - out = swiglu_torch(out, alpha, limit) + out = swiglu_torch(out, alpha, limit, add_residual) if gammas is not None: out *= gammas[lo:hi, None] y[lo:hi, :] = out diff --git a/aiter/ops/triton/moe/moe_op_gemm_a8w8_blockscale.py b/aiter/ops/triton/moe/moe_op_gemm_a8w8_blockscale.py index f417827ffe..b14e5fcd32 100644 --- a/aiter/ops/triton/moe/moe_op_gemm_a8w8_blockscale.py +++ b/aiter/ops/triton/moe/moe_op_gemm_a8w8_blockscale.py @@ -7,8 +7,8 @@ from aiter.ops.triton.moe.moe_routing.routing import RoutingData from aiter.ops.triton._triton_kernels.moe.moe_op_gemm_a8w8_blockscale import ( _moe_gemm_a8w8_blockscale, - _reduce_grouped, ) +from aiter.ops.triton.moe.reduce import reduce_grouped # ----------------------------------------------------------------------------- # Matrix Multiplication + Outer Gather/Scatter @@ -114,80 +114,6 @@ def get_kernel_config(m, n, k, routing_data): return ret -def reduce_grouped( - x: torch.Tensor, - indx: torch.Tensor, - out: torch.Tensor, - apply_swiglu=False, - alpha=1.0, - limit=None, - reduction_n=1, - out_dtype: bool = None, -): - """ - In-place grouped row reduction. - - Arguments - - x: Tensor[AnyFloat] of shape [(num_groups * K), N] - - indx: Tensor[Int] of shape [num_groups, K] - - Description - For each group g in [0, num_groups), this routine sums the K rows of `x` - specified by `indx[g, :]` and overwrites the row corresponding to the first - valid (non-negative) index with the per-group sum. Accumulation is performed - in float32 for numerical stability, and the result is written back in the - dtype of `x`. - - Behavior and edge cases - - Invalid (-1) entries are skipped during accumulation and do not generate - memory traffic. If a group has no valid entries, nothing is written for - that group. - - Reduction is performed tile-by-tile along the N dimension within a single - kernel launch (persistent along N) to minimize launch overhead. - - Performance notes - - Memory traffic per group is approximately (valid_rows_read + 1) * N * sizeof(x), - plus index reads. With no invalid entries, this becomes (K + 1) reads/writes - of length N per group. - - Returns - - The input tensor `x` (modified in place). 
- """ - if indx is None and x.shape[0] == 1: - return x.squeeze(0) - if indx is not None: - num_groups = indx.shape[0] - else: - num_groups = x.shape[-2] - K = 1 if indx is None else indx.shape[1] - out_dtype = x.dtype if out_dtype is None else out_dtype - assert x.shape[-1] % reduction_n == 0 - BLOCK_N = 512 - num_blocks = triton.cdiv(x.shape[-1], BLOCK_N) - - _reduce_grouped[(num_blocks, num_groups)]( - x, - x.stride(0), - x.stride(1), - x.stride(2), # - out, - out.stride(0), - out.stride(1), # - indx, # - x.shape[0], - x.shape[-1], # - apply_swiglu, - alpha, - limit, - reduction_n, - BLOCK_N=BLOCK_N, - EVEN_N=(x.shape[-1] % BLOCK_N == 0), - K=K, # - num_warps=2, # - ) - return out - - # ----------------------------------------------------------------------------- # Triton Implementation # ----------------------------------------------------------------------------- @@ -210,6 +136,7 @@ def moe_gemm_a8w8_blockscale( apply_swiglu=False, alpha=1.0, limit=None, + add_residual=True, unpadded_N=None, unpadded_K=None, per_row_x_scale=False, @@ -326,6 +253,7 @@ def moe_gemm_a8w8_blockscale( alpha, limit, reduction_n_matmul, + add_residual, routing_data.n_expts_act, config["block_m"], config["block_n"], @@ -362,6 +290,7 @@ def moe_gemm_a8w8_blockscale( limit, reduction_n_reduction, out_dtype=out_dtype, + add_residual=add_residual, ) return y_final @@ -371,7 +300,7 @@ def moe_gemm_a8w8_blockscale( # ----------------------------------------------------------------------------- -def swiglu_torch(a, alpha, limit): +def swiglu_torch(a, alpha, limit, add_residual=True): a_gelu = a[..., ::2] if limit is not None: a_gelu = a_gelu.clamp(max=limit) @@ -380,7 +309,10 @@ def swiglu_torch(a, alpha, limit): a_linear = a_linear.clamp(min=-limit, max=limit) out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu) - out = out_gelu * (a_linear + 1) + if add_residual: + out = out_gelu * (a_linear + 1) + else: + out = out_gelu * a_linear return out @@ -395,6 +327,7 @@ def moe_gemm_torch( apply_swiglu=False, alpha=1.0, limit=None, + add_residual=True, ): assert x.dtype.itemsize > 1 assert w.dtype.itemsize > 1 @@ -424,7 +357,7 @@ def moe_gemm_torch( if bias is not None: out += bias[i, :] if apply_swiglu: - out = swiglu_torch(out, alpha, limit) + out = swiglu_torch(out, alpha, limit, add_residual) if gammas is not None: out *= gammas[lo:hi, None] y[lo:hi, :] = out diff --git a/aiter/ops/triton/moe/moe_op_gemm_int8_smoothquant.py b/aiter/ops/triton/moe/moe_op_gemm_int8_smoothquant.py index dbc51ec378..4b710afd45 100644 --- a/aiter/ops/triton/moe/moe_op_gemm_int8_smoothquant.py +++ b/aiter/ops/triton/moe/moe_op_gemm_int8_smoothquant.py @@ -8,8 +8,9 @@ from aiter.ops.triton.utils.device_info import get_num_sms from aiter.ops.triton._triton_kernels.moe.moe_op_gemm_int8_smoothquant import ( _moe_gemm_int8_smoothquant, - _reduce_grouped, ) +from aiter.ops.triton.moe.reduce import reduce_grouped +from aiter.ops.triton.utils._triton.arch_info import get_arch # ----------------------------------------------------------------------------- # Matrix Multiplication + Outer Gather/Scatter @@ -114,7 +115,7 @@ def get_kernel_config(m, n, k, routing_data): block_k = 256 num_warps = 4 num_stages = 2 - kpack = 2 + kpack = 2 if get_arch() == "gfx942" else 1 grid_m = routing_data.n_blocks(m, block_m) grid_n = triton.cdiv(n, block_n) @@ -147,83 +148,6 @@ def get_kernel_config(m, n, k, routing_data): return ret -def reduce_grouped( - x: torch.Tensor, - indx: torch.Tensor, - out: torch.Tensor, - alpha=1.0, - limit=1.0, - reduction_n=1, - 
apply_activation: bool = False, - out_dtype: torch.dtype = None, - add_residual: bool = False, -): - """ - In-place grouped row reduction. - - Arguments - - x: Tensor[AnyFloat] of shape [(num_groups * K), N] - - indx: Tensor[Int] of shape [num_groups, K] - - Description - For each group g in [0, num_groups), this routine sums the K rows of `x` - specified by `indx[g, :]` and overwrites the row corresponding to the first - valid (non-negative) index with the per-group sum. Accumulation is performed - in float32 for numerical stability, and the result is written back in the - dtype of `x`. - - Behavior and edge cases - - Invalid (-1) entries are skipped during accumulation and do not generate - memory traffic. If a group has no valid entries, nothing is written for - that group. - - Reduction is performed tile-by-tile along the N dimension within a single - kernel launch (persistent along N) to minimize launch overhead. - - Performance notes - - Memory traffic per group is approximately (valid_rows_read + 1) * N * sizeof(x), - plus index reads. With no invalid entries, this becomes (K + 1) reads/writes - of length N per group. - - Returns - - The input tensor `x` (modified in place). - """ - if indx is None and x.shape[0] == 1: - return x.squeeze(0) - - if indx is not None: - num_groups = indx.shape[0] - else: - num_groups = x.shape[-2] - K = 1 if indx is None else indx.shape[1] - out_dtype = x.dtype if out_dtype is None else out_dtype - assert x.shape[-1] % reduction_n == 0 - BLOCK_N = 512 - num_blocks = triton.cdiv(x.shape[-1], BLOCK_N) - - _reduce_grouped[(num_blocks, num_groups)]( - x, - x.stride(0), - x.stride(1), - x.stride(2), - out, - out.stride(0), - out.stride(1), - indx, - x.shape[0], - x.shape[-1], - alpha, - limit, - reduction_n, - BLOCK_N=BLOCK_N, - EVEN_N=(x.shape[-1] % BLOCK_N == 0), - K=K, - num_warps=2, - ADD_RESIDUAL=add_residual, - APPLY_ACTIVATION=apply_activation, - ) - return out - - # ----------------------------------------------------------------------------- # Triton Implementation # ----------------------------------------------------------------------------- @@ -372,11 +296,10 @@ def moe_gemm_int8_smoothquant( y, group_indx, y_final, + apply_activation and (config["split_k"] > 1), # apply activation if split_k > 1 alpha, limit, reduction_n_reduction, - apply_activation=(alpha != 0) - and (config["split_k"] > 1), # apply activation if split_k > 1 out_dtype=out_dtype, add_residual=add_residual, ) diff --git a/aiter/ops/triton/moe/reduce.py b/aiter/ops/triton/moe/reduce.py new file mode 100644 index 0000000000..6a17af6813 --- /dev/null +++ b/aiter/ops/triton/moe/reduce.py @@ -0,0 +1,73 @@ +import torch +import triton +from aiter.ops.triton._triton_kernels.moe.reduce import _reduce_grouped + + +def reduce_grouped( + x: torch.Tensor, + indx: torch.Tensor, + out: torch.Tensor, + apply_swiglu=False, + alpha=1.0, + limit=1.0, + reduction_n=1, + out_dtype=None, + add_residual: bool = True, +): + """ + Grouped row reduction used during moe scatter and also compatible with split-k reduce. + + Arguments + - x: Tensor[AnyFloat] of shape [(num_groups * K), N] + - indx: Tensor[Int] of shape [num_groups, K] + + Description + For each group g in [0, num_groups), this routine sums the K rows of `x` + specified by `indx[g, :]`. Accumulation is performed + in float32 for numerical stability, and the result is written back in the + dtype of `x`. + + Performance notes + - Memory traffic per group is approximately (valid_rows_read + 1) * N * sizeof(x), + plus index reads. 
With no invalid entries, this becomes (K + 1) reads/writes + of length N per group. + + Returns + - The output tensor `out`. + """ + + if indx is None and x.shape[0] == 1: + return x.squeeze(0) + if indx is not None: + num_groups = indx.shape[0] + else: + num_groups = x.shape[-2] + K = 1 if indx is None else indx.shape[1] + out_dtype = x.dtype if out_dtype is None else out_dtype + assert x.shape[-1] % reduction_n == 0 + BLOCK_N = 512 + num_blocks = triton.cdiv(x.shape[-1], BLOCK_N) + + _reduce_grouped[(num_blocks * num_groups,)]( + x, + x.stride(0), + x.stride(1), + x.stride(2), # + out, + out.stride(0), + out.stride(1), # + indx, # + x.shape[0], + x.shape[-1], + num_blocks, + apply_swiglu, + alpha, + limit, + reduction_n, + BLOCK_N=BLOCK_N, + EVEN_N=(x.shape[-1] % BLOCK_N == 0), + K=K, # + ADD_RESIDUAL=add_residual, + num_warps=2, # + ) + return out
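Usage sketch (illustrative, not part of the patch): the refactor above moves _swiglu and _reduce_grouped out of the per-kernel files into shared modules, with aiter/ops/triton/moe/reduce.py exposing the reduce_grouped host wrapper. The snippet below is a minimal sketch of driving that wrapper directly; the shapes, dtypes, index layout, and the alpha/limit values are assumptions chosen for illustration, not values taken from aiter's tests or configs.

import torch
from aiter.ops.triton.moe.reduce import reduce_grouped

# Assumed layout: split-k partial buffers stacked on dim 0, one scattered row per
# (token, expert) pair on dim 1, and the gated intermediate width on dim 2.
split_k, num_rows, N = 2, 8, 1024
num_tokens, n_expts_act = 4, 2

x = torch.randn(split_k, num_rows, N, device="cuda", dtype=torch.float32)
indx = torch.arange(num_tokens * n_expts_act, device="cuda", dtype=torch.int32)
indx = indx.view(num_tokens, n_expts_act)

# With apply_swiglu=True and reduction_n=2, the fused activation halves the N
# dimension, so the output holds N // 2 columns per token.
out = torch.empty(num_tokens, N // 2, device="cuda", dtype=torch.bfloat16)

reduce_grouped(
    x,
    indx,
    out,
    apply_swiglu=True,
    alpha=1.702,        # assumed value; alpha=1.0 reduces the gate to a plain SiLU
    limit=7.0,          # assumed clipping limit; None disables clipping
    reduction_n=2,
    out_dtype=torch.bfloat16,
    add_residual=True,  # s * (linear + 1); add_residual=False gives s * linear
)

When split_k > 1 the moe_gemm_* wrappers defer the SwiGLU from the GEMM kernel to this reduction step in the same way, which is why the kernels only apply it in-kernel under APPLY_SWIGLU and SPLIT_K == 1.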