diff --git a/aiter/ops/triton/_triton_kernels/moe_op_gemm_a8w4.py b/aiter/ops/triton/_triton_kernels/moe_op_gemm_a8w4.py new file mode 100644 index 0000000000..ffc413a97f --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/moe_op_gemm_a8w4.py @@ -0,0 +1,505 @@ +# adapted from triton_kernels package +# original code https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py + +import torch +import triton +import triton.language as tl +from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid +from aiter.ops.triton._triton_kernels.quant_moe import _compute_static_fp8_quant + + +def matmul_launch_metadata(grid, kernel, args): + ret = dict() + M, N, K = None, args["N"], args["K"] + Y, X, W = args["Y"], args["X"], args["W"] + hist = args["ExptHist"] + if hist is not None: + n_rows = int(hist.float().mean()) + n_tokens = float(hist.sum()) + n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum() + else: + n_tokens = None + n_w_bytes = W.numel() * W.element_size() + repr = lambda s, x: f"{s}={x}" if x is not None else f"E_{len(hist)}({s})={n_rows}" + nbits = X.dtype.itemsize * 8 + ret["name"] = f"{kernel.name} [{repr('M', M)}, {repr('N', N)}, {repr('K', K)}]" + if args["B"] is not None: + ret["name"] += "_bias" + if args["APPLY_SWIGLU"]: + ret["name"] += "_swiglu" + if args["Quant_static_scale"] is not None: + ret["name"] += "_quant" + + fM = n_tokens + fK = K if K is not None else n_tokens + ret[f"flops{nbits}"] = 2.0 * fM * N * fK + + gindx = args.get("GatherIndx", None) + # sindx = args.get("WriteBackIndx", None) + n_x_bytes = X.numel() * X.element_size() + n_y_bytes = Y.numel() * Y.element_size() + if hist is not None: + assert n_tokens is not None + n_expts_act = args["N_EXPTS_ACT"] + + if gindx is not None: + # recreate inverse GatherIndx. 
+ dst = torch.full_like(gindx, -1) + idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32) + mask = gindx != -1 + dst[gindx[mask]] = idx[mask] + n_read_rows = (dst.view((-1, n_expts_act)) != -1).any(dim=1).sum() + else: + n_read_rows = n_tokens + n_x_bytes = n_read_rows * X.shape[-1] * X.element_size() + n_y_bytes = n_tokens * Y.shape[-1] * Y.element_size() + ret["bytes"] = int(n_x_bytes + n_y_bytes + n_w_bytes) + + return ret + + +# TODO: using aiter swizzle instead can lead to perf degradation in rare cases +@triton.jit +def xcd_swizzle(pid, domain_size, XCD_SWIZZLE: tl.constexpr): + """ + Swizzle the program id based on integer XCD_SWIZZLE. + This is useful for reording how blocks are ordered. A scheduler may, for example, + assign sequential blocks 0, 1, 2, 3, ..., 8, 9, 10.. to its 8 hardware units 0, 1, 2, 3, ..., 0, 1, 2. + This pattern may not be ideal for memory access, and it may be better to swizzle so the assignment + becomes 0, 0, 0, 0, ..., 1, 1, 1, ... In the swizzled arrangement, sequential blocks are assigned to + the same hardware unit. 
@triton.jit
def _swiglu(input, alpha, limit):
    """Fused SwiGLU-style activation over interleaved (gate, linear) columns.

    The last dimension of ``input`` is viewed as pairs: even columns are the
    gate, odd columns the linear part. Both are promoted to fp32. When
    ``limit`` is given the gate is clamped from above only and the linear part
    is clamped to ``[-limit, limit]``. The result has half as many columns as
    ``input`` and equals ``gate * sigmoid(alpha * gate) * (linear + 1)``.
    """
    pairs = tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2))
    gate, lin = tl.split(pairs)
    gate = gate.to(tl.float32)
    lin = lin.to(tl.float32)
    if limit is not None:
        gate = clip(gate, limit, clip_lower=False)
        lin = clip(lin, limit, clip_lower=True)
    # 1.44269504089 == log2(e): exp(-alpha * g) is evaluated through exp2.
    gated = gate / (1 + tl.exp2(-1.44269504089 * alpha * gate))
    # fma(gated, lin, gated) == gated * (lin + 1)
    return tl.fma(gated, lin, gated)
@triton.jit(launch_metadata=matmul_launch_metadata)
def _moe_gemm_a8w4(
    Y,
    stride_y_k,
    stride_y_m,
    stride_y_n,
    X,
    stride_x_m,
    stride_x_k,
    XMxScale,
    stride_x_mx_m,
    stride_x_mx_k,
    W,
    stride_w_e,
    stride_w_k,
    stride_w_n,
    WMxScale,
    stride_w_mx_e,
    stride_w_mx_k,
    stride_w_mx_n,
    X_static_scale,
    Quant_static_scale,
    B,
    stride_b_e,  # Bias
    Gammas,
    N,
    K,  # shapes
    # expt data
    GatherIndx,
    ExptHist,
    ExptOffs,
    ExptOffsSum,
    ExptData,
    # true grid size
    grid_m,
    grid_n,
    # fused activation function
    APPLY_SWIGLU: tl.constexpr,
    alpha,
    limit,
    ACTIVATION_REDUCTION_N: tl.constexpr,
    # MoE config
    N_EXPTS_ACT: tl.constexpr,
    # optimization config
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    XCD_SWIZZLE: tl.constexpr,
    # One of ["CDNA4_SCALE", None]
    SWIZZLE_MX_SCALE: tl.constexpr,
    EVEN_K: tl.constexpr,
    MASK_K_LIMIT: tl.constexpr,
    SPLIT_K: tl.constexpr,
    W_CACHE_MODIFIER: tl.constexpr,
    UPCAST_INDICES: tl.constexpr = False,
):
    """Grouped (per-expert) GEMM with 8-bit activations and 4-bit MX weights.

    Each program computes one ``BLOCK_M x BLOCK_N`` output tile for one
    expert and one split-K slice:

    * ``ExptData[pid_m]`` packs the expert id in its low 16 bits and the
      M-block index in the remaining high bits; ``-1`` marks a padding tile.
    * ``W`` holds two 4-bit values per ``uint8`` along K (``W_K_DIVISOR``),
      with one ``uint8`` scale in ``WMxScale`` per 32 K-elements.
    * ``X`` is multiplied as ``e4m3``; if ``XMxScale`` is ``None`` a constant
      scale byte of 127 is used instead of per-block activation scales.
    * Rows are optionally gathered through ``GatherIndx``.

    Epilogue, in order: ``X_static_scale`` multiply, bias (added only by the
    ``pid_k == 0`` slice), SwiGLU (only when ``SPLIT_K == 1``), per-row
    ``Gammas``, optional static fp8 quantization, masked store to ``Y``.
    """

    tl.assume(stride_y_k >= 0)
    tl.assume(stride_y_m >= 0)
    tl.assume(stride_y_n >= 0)
    tl.assume(stride_x_m >= 0)
    tl.assume(stride_x_k >= 0)
    tl.assume(stride_w_e >= 0)
    tl.assume(stride_w_k >= 0)
    tl.assume(stride_w_n >= 0)
    if stride_x_mx_m is not None:
        tl.assume(stride_x_mx_m >= 0)
    if stride_x_mx_k is not None:
        tl.assume(stride_x_mx_k >= 0)
    if stride_w_mx_e is not None:
        tl.assume(stride_w_mx_e >= 0)
    if stride_w_mx_k is not None:
        tl.assume(stride_w_mx_k >= 0)
    if stride_w_mx_n is not None:
        tl.assume(stride_w_mx_n >= 0)
    if B is not None:
        tl.assume(stride_b_e >= 0)
    tl.assume(grid_m >= 0)
    tl.assume(grid_n >= 0)

    is_x_microscaled: tl.constexpr = XMxScale is not None
    MX_PACK_DIVISOR: tl.constexpr = 32
    w_type: tl.constexpr = W.dtype.element_ty
    # Only the packed-fp4 layout is supported: the dot below is hard-wired to "e2m1".
    tl.static_assert(
        w_type == tl.uint8, "mx_weight_ptr must be uint8 (packed fp4 / e2m1)"
    )
    tl.static_assert(
        WMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8"
    )
    tl.static_assert(
        BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR"
    )
    x_type: tl.constexpr = X.dtype.element_ty
    if is_x_microscaled:
        tl.static_assert(x_type == tl.float8e4nv, "mx_act_ptr must be float8e4nv")
        tl.static_assert(
            XMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8"
        )

    OUT_BLOCK_N: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
    yN = N // ACTIVATION_REDUCTION_N

    pid = tl.program_id(0)
    if ExptOffsSum is not None and XCD_SWIZZLE > 1:
        # Determine how much padding there is on the expert data. This allows us to
        # know the true grid size and avoid processing padding tiles.
        padding_m = grid_m - tl.load(ExptOffsSum)
    else:
        padding_m: tl.constexpr = 0

    index_type: tl.constexpr = tl.int64 if UPCAST_INDICES else tl.int32

    unpadded_m = grid_m - padding_m
    tl.assume(unpadded_m >= 0)
    total_actual_tiles = unpadded_m * grid_n * SPLIT_K
    if padding_m > 0 and pid >= total_actual_tiles:
        return

    # swizzle program ids
    pid_emnk = pid
    if XCD_SWIZZLE != 1:
        pid_emnk = xcd_swizzle(pid_emnk, total_actual_tiles, XCD_SWIZZLE)
    pid_mnk = pid_emnk % (unpadded_m * grid_n * SPLIT_K)
    pid_k = pid_mnk % SPLIT_K
    pid_mn = pid_mnk // SPLIT_K
    pid_m, pid_n = pid_grid(pid_mn, unpadded_m, grid_n, GROUP_M)
    # For split-k, advance to the output k slice
    if SPLIT_K > 1:
        Y += pid_k.to(index_type) * stride_y_k
    # unpack expert data
    expt_data = tl.load(ExptData + pid_m)
    if expt_data == -1:
        return
    expt_id = expt_data & 0x0000FFFF
    block_id = expt_data >> 16
    M = tl.load(ExptHist + expt_id)
    start_m = tl.load(ExptOffs + expt_id)
    expt_id, block_id = expt_id.to(index_type), block_id.to(index_type)
    start_m = start_m.to(index_type)
    pid_n, pid_k = pid_n.to(index_type), pid_k.to(index_type)

    # A pointers
    offs_x_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)
    # Rows wrap modulo M so loads never need a row mask; the store is masked instead.
    offs_x_m = tl.max_contiguous(tl.multiple_of(offs_x_m % M, BLOCK_M), BLOCK_M)
    if GatherIndx is None:
        X += start_m * stride_x_m
    else:
        GatherIndx += start_m
        # no need to bounds-check here because `offs_x_m` wraps around M dim
        offs_x_m = tl.load(GatherIndx + offs_x_m) // N_EXPTS_ACT
    offs_x_k = BLOCK_K * pid_k + tl.arange(0, BLOCK_K)
    XPtrs = (
        X
        + offs_x_m.to(index_type)[:, None] * stride_x_m
        + offs_x_k.to(index_type)[None, :] * stride_x_k
    )

    W_K_DIVISOR: tl.constexpr = 2  # two 4-bit values per uint8 along K
    W_N_DIVISOR: tl.constexpr = 1
    PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K // W_K_DIVISOR
    PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_N_DIVISOR
    MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR

    WMxScale += expt_id * stride_w_mx_e
    if SWIZZLE_MX_SCALE == "CDNA4_SCALE":
        tl.static_assert(stride_w_mx_k is not None)
        tl.static_assert(stride_w_mx_n is not None)
        NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 32
        PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * NON_K_PRESHUFFLE_BLOCK_SIZE
        SCALE_BLOCK_N: tl.constexpr = BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE
    else:
        PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K
        SCALE_BLOCK_N: tl.constexpr = BLOCK_N
    offs_w_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N
    offs_w_n_scale = tl.max_contiguous(
        tl.multiple_of(offs_w_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N
    )
    # K dimension must be the last dimension for the scales
    offs_w_k_scale = PACKED_MX_BLOCK * pid_k + tl.arange(0, PACKED_MX_BLOCK)
    WMxScalePtrs = (
        WMxScale
        + offs_w_k_scale.to(index_type)[None, :] * stride_w_mx_k
        + offs_w_n_scale.to(index_type)[:, None] * stride_w_mx_n
    )

    # B pointers
    offs_w_n = pid_n * PACKED_BLOCK_N_W + tl.arange(0, PACKED_BLOCK_N_W)
    offs_w_n = tl.max_contiguous(
        tl.multiple_of(offs_w_n % (N // W_N_DIVISOR), PACKED_BLOCK_N_W),
        PACKED_BLOCK_N_W,
    )
    offs_w_k = PACKED_BLOCK_K_W * pid_k + tl.arange(0, PACKED_BLOCK_K_W)
    W += expt_id * stride_w_e
    WPtrs = W + (
        offs_w_k.to(index_type)[:, None] * stride_w_k
        + offs_w_n.to(index_type)[None, :] * stride_w_n
    )

    if is_x_microscaled:
        if GatherIndx is None:
            XMxScale += start_m * stride_x_mx_m
        offs_x_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K)
        XMxScalePtrs = (
            XMxScale
            + offs_x_m.to(index_type)[:, None] * stride_x_mx_m
            + offs_x_k_scale.to(index_type)[None, :] * stride_x_mx_k
        )

    # When K is not evenly tiled, the last iteration is peeled off below and masked.
    num_k_iter = tl.cdiv(K, BLOCK_K * SPLIT_K)
    if not EVEN_K:
        num_k_iter -= 1

    # compute output
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(num_k_iter):
        x = tl.load(XPtrs)
        w = tl.load(WPtrs, cache_modifier=W_CACHE_MODIFIER)

        if is_x_microscaled:
            x_scales = tl.load(XMxScalePtrs)
        else:
            # Constant scale byte used when X has no per-block scales.
            x_scales = tl.full((BLOCK_M, MX_SCALE_BLOCK_K), 127, dtype=tl.uint8)
        if SWIZZLE_MX_SCALE == "CDNA4_SCALE":
            w_scales = unswizzle_mx_scale_cdna4(
                tl.load(WMxScalePtrs, cache_modifier=W_CACHE_MODIFIER),
                BLOCK_N,
                MX_SCALE_BLOCK_K,
            )
        else:
            w_scales = tl.load(WMxScalePtrs)

        acc = tl.dot_scaled(
            x, x_scales, "e4m3", w, w_scales, "e2m1", acc=acc, fast_math=True
        )

        WMxScalePtrs += (PACKED_MX_BLOCK * SPLIT_K) * stride_w_mx_k
        if is_x_microscaled:
            XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * stride_x_mx_k

        XPtrs += (BLOCK_K * SPLIT_K) * stride_x_k
        WPtrs += (PACKED_BLOCK_K_W * SPLIT_K) * stride_w_k

    if not EVEN_K:
        mask_x_k = offs_x_k < MASK_K_LIMIT
        mask_w_k = offs_w_k < (MASK_K_LIMIT // W_K_DIVISOR)
        if SWIZZLE_MX_SCALE is None:
            mask_w_k_scale = offs_w_k_scale * MX_PACK_DIVISOR < MASK_K_LIMIT
        if is_x_microscaled:
            mask_x_k_scale = offs_x_k_scale * MX_PACK_DIVISOR < MASK_K_LIMIT

        x = tl.load(XPtrs, mask=mask_x_k[None, :], other=0.0)
        w = tl.load(
            WPtrs, mask=mask_w_k[:, None], other=0, cache_modifier=W_CACHE_MODIFIER
        )

        if is_x_microscaled:
            # Give masked lanes a defined scale byte: without `other` the value is
            # undefined and could decode to a non-finite scale that would poison
            # the (zero) masked products.
            x_scales = tl.load(XMxScalePtrs, mask=mask_x_k_scale[None, :], other=0)
        else:
            x_scales = tl.full((BLOCK_M, MX_SCALE_BLOCK_K), 127, dtype=tl.uint8)
        if SWIZZLE_MX_SCALE == "CDNA4_SCALE":
            # NOTE(review): the swizzled scale tile is loaded unmasked; presumably the
            # pre-shuffled scale buffer is padded to a full tile -- confirm with callers.
            w_scales = unswizzle_mx_scale_cdna4(
                tl.load(WMxScalePtrs, cache_modifier=W_CACHE_MODIFIER),
                BLOCK_N,
                MX_SCALE_BLOCK_K,
            )
        else:
            # Same reasoning as for x_scales: keep masked scale lanes well-defined.
            w_scales = tl.load(WMxScalePtrs, mask=mask_w_k_scale[None, :], other=0)

        acc = tl.dot_scaled(
            x, x_scales, "e4m3", w, w_scales, "e2m1", acc=acc, fast_math=True
        )

    # scalar fp8 scale
    if X_static_scale is not None:
        acc = acc * tl.load(X_static_scale)
    # bias
    offs_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)
    offs_y_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N)
    mask_m = offs_m < M
    mask_n = offs_y_n < N
    if B is not None:
        BPtrs = B + expt_id * stride_b_e + offs_y_n
        # Only the first split-K slice adds the bias so it is counted once
        # after the partial results are reduced.
        if pid_k == 0:
            bias = tl.load(BPtrs, mask=mask_n, other=0, cache_modifier=W_CACHE_MODIFIER)
        else:
            bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
        acc = acc + bias[None, :]
    if APPLY_SWIGLU and SPLIT_K == 1:
        out = _swiglu(acc, alpha, limit)
        tl.static_assert(
            out.shape[1] == OUT_BLOCK_N,
            f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})",
        )
        offs_y_n = OUT_BLOCK_N * pid_n + tl.arange(0, OUT_BLOCK_N)
        mask_n = offs_y_n < yN
    else:
        tl.static_assert(
            ACTIVATION_REDUCTION_N == 1,
            "Activation reduction must be 1 if no activation fn is provided",
        )
        out = acc
    if Gammas is not None:
        gammas = tl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0)
        out *= gammas[:, None]
    # quant
    if Quant_static_scale is not None:
        out = _compute_static_fp8_quant(out, tl.load(Quant_static_scale))
    # write-back
    Y += start_m * stride_y_m
    offs_y_m = offs_m
    YPtrs = (
        Y
        + offs_y_m.to(index_type)[:, None] * stride_y_m
        + offs_y_n.to(index_type)[None, :] * stride_y_n
    )
    mask = mask_m[:, None] & mask_n[None, :]
    tl.store(YPtrs, out, mask=mask)
@triton.jit
def _sum_bitmatrix_memset(Ret, BLOCK: tl.constexpr):
    """Zero one BLOCK-sized chunk of ``Ret``; program ``p`` clears
    elements ``[p * BLOCK, (p + 1) * BLOCK)`` with an unmasked store."""
    chunk_start = tl.program_id(0) * BLOCK
    lanes = tl.arange(0, BLOCK)
    tl.store(Ret + (chunk_start + lanes), 0)
@triton.jit
def _expt_data_compute_stage1(
    pid,
    Hist,
    n_expts_tot,
    TokenStart,
    TileStart,
    MDTileInfo,
    max_num_tiles,
    n_gates,
    tile_dim_log2: tl.constexpr,
    BLOCK: tl.constexpr,
    EQUAL_BLOCK: tl.constexpr,
):
    """Compute per-expert token and tile start offsets from the histogram.

    ``TokenStart[e]`` / ``TileStart[e]`` receive the exclusive prefix sums of
    ``Hist`` and of ``ceil(Hist / 2**tile_dim_log2)``. Every calling program
    recomputes and stores the same prefix sums. Program 0 additionally writes
    the sentinel entries at index ``n_expts_tot`` (``n_gates`` and the total
    tile count) and fills ``MDTileInfo`` from the total tile count up to
    ``max_num_tiles`` with ``0xFFFFFFFF``.

    ``EQUAL_BLOCK`` selects a single unmasked pass over exactly ``BLOCK``
    experts; otherwise the experts are processed in masked chunks of
    ``BLOCK``.
    """
    if EQUAL_BLOCK:
        offs_n = tl.arange(0, BLOCK)
        hist_token = tl.load(Hist + offs_n)
        hist_tile = _cdiv_pow2(hist_token, tile_dim_log2)
        token_starts = tl.cumsum(hist_token, 0) - hist_token
        tile_starts = tl.cumsum(hist_tile, 0) - hist_tile
        tl.store(TokenStart + offs_n, token_starts)
        tl.store(TileStart + offs_n, tile_starts)
    else:
        # Running totals carried from one chunk of experts to the next.
        token_acc = tl.zeros([BLOCK], dtype=TokenStart.dtype.element_ty)
        tile_acc = tl.zeros([BLOCK], dtype=TileStart.dtype.element_ty)
        offs_n = tl.arange(0, BLOCK)
        for _i in range(0, n_expts_tot, BLOCK):
            mask_n = offs_n < n_expts_tot
            hist_token = tl.load(Hist + offs_n, mask=mask_n, other=0)
            hist_tile = _cdiv_pow2(hist_token, tile_dim_log2)
            token_starts = tl.cumsum(hist_token, 0) - hist_token + token_acc
            tile_starts = tl.cumsum(hist_tile, 0) - hist_tile + tile_acc
            token_acc += tl.sum(hist_token, 0)
            tile_acc += tl.sum(hist_tile, 0)
            # Mask the stores like the load: lanes at or beyond n_expts_tot must not
            # be written. They would run past the expert range of the last chunk and
            # clobber the sentinel slot that only program 0 is meant to write.
            tl.store(TokenStart + offs_n, token_starts, mask=mask_n)
            tl.store(TileStart + offs_n, tile_starts, mask=mask_n)
            offs_n += BLOCK

    if pid == 0:
        tl.store(TokenStart + n_expts_tot, n_gates)

        # Total tile count = start of the last expert + its own tile count.
        hist_tok_last = tl.load(Hist + n_expts_tot - 1)
        hist_tile_last = _cdiv_pow2(hist_tok_last, tile_dim_log2)
        tile_off_last = tl.load(TileStart + n_expts_tot - 1) + hist_tile_last
        tl.store(TileStart + n_expts_tot, tile_off_last)

        # Mark every unused tile slot with the all-ones bit pattern.
        MEMSET_BLOCK: tl.constexpr = 16
        for block_off in range(tile_off_last, max_num_tiles, MEMSET_BLOCK):
            block_offs = block_off + tl.arange(0, MEMSET_BLOCK)
            tl.store(
                MDTileInfo + block_offs, 0xFFFFFFFF, mask=block_offs < max_num_tiles
            )
@triton.jit
def _combined_routing(
    GatherIndx,
    ScatterIndx,
    GateScal,
    ExptScal,
    ExptIndx,
    PartialOffs,
    stride_pm,
    stride_pn,
    n_gates,
    BLOCK_M: tl.constexpr,
    EVEN_M: tl.constexpr,
    N_EXPTS_ACT: tl.constexpr,
    ExpertHist,
    n_expts_tot,
    TokenStart,
    TileStart,
    blocks1a,
    MDTileInfo,
    max_num_tiles,
    tile_dim_log2: tl.constexpr,
    BLOCK_A: tl.constexpr,
    EQUAL_A: tl.constexpr,
):
    """Single-launch routing kernel: expert metadata plus gather/scatter indices.

    Every program first runs `_expt_data_compute_stage1`, which stores the
    `TokenStart` / `TileStart` values this program's later loads read back.
    The grid is then split by program id:

    * programs ``[0, blocks1a)`` run `_expt_data_compute_stage2` with the
      program id as the expert id, writing tile info into `MDTileInfo`;
    * the remaining programs run `_routing_compute_indx` on the row block
      ``pid - blocks1a``.
    """

    pid = tl.program_id(0)

    # Stage 1 is executed by every program, not just a subset.
    _expt_data_compute_stage1(
        pid,
        ExpertHist,
        n_expts_tot,
        TokenStart,
        TileStart,
        MDTileInfo,
        max_num_tiles,
        n_gates,
        tile_dim_log2,
        BLOCK_A,
        EQUAL_A,
    )

    if pid < blocks1a:
        # Here `pid` is used as the expert id.
        _expt_data_compute_stage2(pid, ExpertHist, TileStart, MDTileInfo, tile_dim_log2)
    else:
        # Rebase so the index computation sees row-block ids starting at 0.
        pid -= blocks1a
        _routing_compute_indx(
            pid,
            GatherIndx,
            ScatterIndx,
            GateScal,
            ExptScal,
            ExptIndx,
            PartialOffs,
            stride_pm,
            stride_pn,
            TokenStart,
            n_gates,
            BLOCK_M,
            EVEN_M,
            N_EXPTS_ACT,
        )
N_EXPTS_PAD: tl.constexpr, + N_EXPTS_ACT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth + x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}") + if x_nbits < 16: + # this ensures that we leave at least 16 bits for expert index + # even if the input dtype is smaller than 16 bits: + y_nbits: tl.constexpr = 32 + else: + y_nbits: tl.constexpr = x_nbits * 2 + x_ultype: tl.constexpr = tl.dtype(f"uint{y_nbits}") + x_dtype: tl.constexpr = X.dtype.element_ty + + # subtract 1 from loop iterations because we peel the first (masked) iteration: + loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1 + offs_x_n = loop_iterations * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_x_n[None, :] < n_expts_tot + + # first iteration: + X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :] + x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf")) + x = fpval_to_key(x.to(x_utype, bitcast=True)) + x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :] + acc = tl.topk(x, N_EXPTS_ACT, dim=1) + + # subsequent iterations: + for _i in (tl.static_range if loop_iterations <= 4 else range)(loop_iterations): + acc = tl.bitonic_merge(acc) # ensure sorted ascending for the merge + X_ptrs -= BLOCK_N + offs_x_n -= BLOCK_N + x = tl.load(X_ptrs, mask=mask_m, other=float("-inf")) + x = fpval_to_key(x.to(x_utype, bitcast=True)) + x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :] + acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT, dim=1)) + + # rotate expert index into upper 16 bits: + # 0000vvvvvvvviiii --> iiii0000vvvvvvvv + acc = (acc << (y_nbits - 16)) | (acc >> 16) + # sort in ascending order of expert (descending order of key) + acc = tl.sort(acc, dim=1, descending=True) + # iiii0000vvvvvvvv --> 0000iiii: + y_indices_raw = (acc >> (y_nbits - 16)).to(tl.uint32) + y_indices = key_to_indx(y_indices_raw, N_EXPTS_PAD) + # iiii0000vvvvvvvv --> vvvvvvvv: + y_values_raw = acc.to(x_utype) + 
@triton.jit
def _topk(
    X,
    stride_xm,  # inputs
    Yv,
    Yi,
    stride_ym,  # topk values/indices
    USE_PROVIDED_INDX: tl.constexpr,
    Bits,
    stride_rm,
    stride_rn,  # bitmatrix
    n_rows,
    n_expts_tot,  # shape
    S,
    BLOCK_S: tl.constexpr,
    s_blocks,  # thing to memset
    SP,
    BLOCK_SP: tl.constexpr,
    sp_blocks,
    sp_size,
    APPLY_SOFTMAX: tl.constexpr,  # constant
    BLOCK_M: tl.constexpr,
    N_EXPTS_PAD: tl.constexpr,
    N_EXPTS_ACT: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Per-row top-k expert selection, with fused zero-fill and bitmatrix packing.

    Each program handles `BLOCK_M` rows of `X`:

    1. The first `s_blocks` programs also zero one `BLOCK_S` chunk of `S`
       (unmasked); the next `sp_blocks` programs zero one `BLOCK_SP` chunk of
       `SP`, masked by `sp_size`.
    2. The `N_EXPTS_ACT` values/indices per row come from `streaming_topk`, or,
       when `USE_PROVIDED_INDX` is set, the indices are read from `Yi` and the
       matching values gathered from `X`.
    3. With `APPLY_SOFTMAX` the selected values are softmaxed in fp32 along
       the selection axis and cast back to the dtype of `X`.
    4. Values go to `Yv` (and indices to `Yi` if they were computed here).
    5. The selected indices are packed into `Bits`: expert `e` sets bit
       `e % 32` of word `e // 32` in that row.

    `n_rows` may be a plain integer or a pointer to one.
    """

    pid = tl.program_id(0)
    # Accept the row count either by value or through a device pointer.
    if isinstance(n_rows, tl.tensor) and n_rows.dtype.is_ptr():
        n_rows = tl.load(n_rows)

    # Fused zero-fill of two auxiliary buffers, spread over the leading programs.
    if pid < s_blocks:
        tl.store(
            S + BLOCK_S * pid + tl.arange(0, BLOCK_S), tl.zeros([BLOCK_S], tl.int32)
        )
    elif pid < s_blocks + sp_blocks:
        offs = BLOCK_SP * (pid - s_blocks) + tl.arange(0, BLOCK_SP)
        tl.store(SP + offs, tl.zeros([BLOCK_SP], tl.int32), mask=offs < sp_size)

    if pid * BLOCK_M >= n_rows:
        # early exit: this program has no rows to process
        return

    tl.static_assert(BLOCK_N % 32 == 0)
    tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0)
    x_dtype: tl.constexpr = X.dtype.element_ty

    # load logits
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_y_n = tl.arange(0, N_EXPTS_ACT)
    mask_m = offs_m[:, None] < n_rows
    if USE_PROVIDED_INDX:
        # Indices were chosen by the caller; only fetch the matching logits.
        Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
        y_indices = tl.load(Yi_ptrs, mask=mask_m)
        Xv_ptrs = X + offs_m[:, None] * stride_xm + y_indices
        y_values = tl.load(Xv_ptrs, mask=mask_m)
    else:
        y_values, y_indices = streaming_topk(
            X,
            stride_xm,
            n_expts_tot,
            offs_m,
            mask_m,  #
            N_EXPTS_PAD,
            N_EXPTS_ACT,
            BLOCK_N,
        )

    # normalize selected values
    if APPLY_SOFTMAX:
        y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(
            x_dtype
        )

    # write back
    Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :]
    tl.store(Yv_ptrs, y_values, mask=mask_m)
    if not USE_PROVIDED_INDX:
        Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
        tl.store(Yi_ptrs, y_indices, mask=mask_m)

    # pack into bitmatrix
    y_div = y_indices // 32  # which 32-bit word the expert falls in
    y_rem = y_indices % 32  # bit position inside that word
    loop_iterations = N_EXPTS_PAD // BLOCK_N
    for i in range(loop_iterations):
        # Each iteration emits BLOCK_N // 32 consecutive words of the row.
        offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32)
        y2 = tl.where(
            y_div[:, :, None] == offs_r_n[None, None, :], (1 << y_rem)[:, :, None], 0
        )
        # OR the one-hot words of all selected experts of a row together.
        r = tl.reduce_or(y2, axis=1)
        BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn
        tl.store(BitsPtrs, r, mask=mask_m)
offs_n * stride_y_n + + x = tl.load(x_ptr + offs_x, mask=mask_xy) + + y = _compute_static_fp8_quant(x, tl.load(scale_ptr)) + + tl.store(y_ptr + offs_y, y, mask=mask_xy) + + +@triton.jit +def _get_max_quant_val(dtype: tl.constexpr): + if dtype == tl.uint8: + return 6.0 + elif dtype == tl.float8e5: + return 57344.0 + elif dtype == tl.float8e4nv: + return 448.0 + else: + tl.static_assert(False, f"Invalid {dtype=}") + + +@triton.jit +def _compute_mx_quant_and_scale( + src_tensor, + valid_src_mask, + mx_tensor_dtype: tl.constexpr, + DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0, +): + is_fp8: tl.constexpr = ( + mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5 + ) + BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0] + BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1] + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // 32 + + # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16 + f32_tensor = src_tensor.to(tl.float32) + abs_tensor = tl.abs(f32_tensor) + abs_tensor = tl.where( + valid_src_mask, abs_tensor, -1.0 + ) # Don't consider padding tensors in scale computation + abs_tensor = tl.reshape( + abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32] + ) + max_val = tl.max(abs_tensor, axis=2, keep_dims=True) + dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype) + if DEQUANT_SCALE_ROUNDING_MODE == 0: + # DequantScaleRoundingMode.ROUND_UP + # compute 2 ** ceil(log2(dequant_scale)) + # Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros + # A corner case: exponent is 0xFF that will overflow but that's already + # NaN so assume we don't care. 
+ dequant_scale_exponent = ( + dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF + ) & 0x7F800000 + else: + # DequantScaleRoundingMode.ROUND_DOWN + # compute 2 ** floor(log2(dequant_scale)) + assert DEQUANT_SCALE_ROUNDING_MODE == 1 + dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000 + dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True) + quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded) + + f32_tensor = tl.reshape( + f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32] + ) + quant_tensor = f32_tensor * quant_scale + + # Reshape the tensors after scaling + quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM]) + # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format. + quant_tensor = tl.where(valid_src_mask, quant_tensor, 0) + dequant_scale_exponent = dequant_scale_exponent.reshape( + [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE] + ) + + # First, we simply extract the exponent part of the scales and store the result + dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8) + # Now we must convert the tensors to the mx format. + if is_fp8: + out_tensor = quant_tensor.to(mx_tensor_dtype) + else: + quant_tensor = quant_tensor.to(tl.uint32, bitcast=True) + signs = quant_tensor & 0x80000000 + exponents = (quant_tensor >> 23) & 0xFF + mantissas = quant_tensor & 0x7FFFFF + + # 0.25 <= x < 0.75 maps to 0.5, a denormal number + E8_BIAS = 127 + E2_BIAS = 1 + # Move implicit bit 1 at the beginning to mantissa for denormals + adjusted_exponents = tl.core.sub( + E8_BIAS, exponents + 1, sanitize_overflow=False + ) + mantissas = tl.where( + exponents < E8_BIAS, + (0x400000 | (mantissas >> 1)) >> adjusted_exponents, + mantissas, + ) + + # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0. 
+ exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS) + + # Combine sign, exponent, and mantissa, while saturating + # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right + e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7) + e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8) + + e2m1_value = tl.reshape( + e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2] + ) + evens, odds = tl.split(e2m1_value) + out_tensor = evens | (odds << 4) + + return out_tensor, dequant_scale_exponent + + +@triton.jit +def _downcast_to_mxfp( + mx_tensor_ptr, + stride_mxt_outer, + stride_mxt_quant: tl.constexpr, + mx_scale_ptr, + stride_mx_scale_outer, + stride_mx_scale_quant, + src_ptr, + stride_src_outer, + stride_src_quant, + outer_dim, + quant_dim, + BLOCK_SIZE_OUT_DIM: tl.constexpr, + BLOCK_SIZE_QUANT_DIM: tl.constexpr, + DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr, +): + + tl.static_assert( + stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1." + ) + tl.static_assert( + BLOCK_SIZE_QUANT_DIM % 32 == 0, + f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32", + ) + + # uint8 signifies two fp4 e2m1 values packed into a single byte + mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty + tl.static_assert( + mx_tensor_dtype == tl.uint8 + or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5), + f"Invalid {mx_tensor_dtype=}. 
Must be uint8 or float8.", + ) + + src_dtype: tl.constexpr = src_ptr.dtype.element_ty + tl.static_assert( + mx_scale_ptr.dtype.element_ty == tl.uint8, + f"{mx_scale_ptr.dtype.element_ty=} must be uint8", + ) + tl.static_assert( + (src_dtype == tl.bfloat16) or (src_dtype == tl.float16), + f"{src_dtype=} must be bfloat16 or float16", + ) + is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8 + + outer_block = tl.program_id(0).to(tl.int64) + quant_block = tl.program_id(1).to(tl.int64) + + K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1 + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // 32 + BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR + + start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM + start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE + start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR + start_out = outer_block * BLOCK_SIZE_OUT_DIM + + src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer + mx_scale_ptr += ( + start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer + ) + mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer + + offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64) + offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64) + offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64) + offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64) + + mask_src_quant = start_src_quant + offs_src_quant < quant_dim + mask_n = start_out + offs_outer < outer_dim + full_mask_src = mask_src_quant & mask_n + + mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR) + full_mask_mxt = mask_mxt_quant & mask_n + + scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, 32) + full_scale_mask = scale_mask_k & mask_n + + src_tensor_offsets = ( + offs_src_quant * stride_src_quant + offs_outer * stride_src_outer 
+ ) + mx_scale_offsets = ( + offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer + ) + mx_tensor_offsets = ( + offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer + ) + src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src) + + out_tensor, scale_tensor = _compute_mx_quant_and_scale( + src_tensor, full_mask_src, mx_tensor_dtype, DEQUANT_SCALE_ROUNDING_MODE + ) + + tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask) + tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt) + + +@triton.jit +def _upcast_from_mxfp( + out_ptr, + stride_o_outer, + stride_o_quant: tl.constexpr, + mx_scale_ptr, + stride_scale_outer, + stride_scale_quant, + mx_tensor_ptr, + stride_tensor_outer, + stride_tensor_quant: tl.constexpr, + outer_dim, + quant_dim, + BLOCK_SIZE_OUT_DIM: tl.constexpr, + BLOCK_SIZE_QUANT_DIM: tl.constexpr, +): + + tl.static_assert( + stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx" + ) + tl.static_assert( + BLOCK_SIZE_QUANT_DIM % 32 == 0, "BLOCK_SIZE_K must be a multiple of 32" + ) + # uint8 signifies two fp4 e2m1 values packed into a single byte + mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty + dst_dtype: tl.constexpr = out_ptr.dtype.element_ty + tl.static_assert(dst_dtype == tl.float16 or dst_dtype == tl.bfloat16) + tl.static_assert( + mx_tensor_dtype == tl.uint8 + or ( + (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) + or mx_tensor_dtype == dst_dtype + ), + "mx_tensor_ptr must be uint8 or float8 or dst_dtype", + ) + tl.static_assert( + mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8" + ) + + # Determine if we are dealing with fp8 types. 
+ is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8 + is_fp8: tl.constexpr = ( + mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5 + ) + K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1 + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // 32 + BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR + + # Compute starting indices for the quantized (packed) dimension and the outer dimension. + outer_block = tl.program_id(0).to(tl.int64) + quant_block = tl.program_id(1).to(tl.int64) + + start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR + start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM + start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE + start_out = outer_block * BLOCK_SIZE_OUT_DIM + + mx_tensor_ptr += ( + start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer + ) + mx_scale_ptr += ( + start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer + ) + out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant + + # Compute offsets and masks. 
+ offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64) + offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64) + offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64) + offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64) + + mask_outer = start_out + offs_outer < outer_dim + mask_out_quant = start_out_quant + offs_out_quant < quant_dim + full_mask_out = mask_out_quant & mask_outer + + mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR) + full_mask_src = mask_src_quant & mask_outer + + mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, 32) + full_scale_mask = mask_scale & mask_outer + + tensor_offsets = ( + offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer + ) + scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer + out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer + + # Load the packed tensor and scale. + tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src) + scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask) + + # Upcast the scale to the destination type. + if dst_dtype == tl.bfloat16: + dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True) + else: + tl.static_assert(dst_dtype == tl.float16) + dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True) + dst_scale = dst_scale.to(tl.float16) + + # Now upcast the tensor. + if is_fp8: + dst_tensor = tensor.to(dst_dtype) + if tensor.dtype == tl.float8e5: + from_e_bits: tl.constexpr = 5 + from_m_bits: tl.constexpr = 2 + to_e_bits: tl.constexpr = 8 if dst_dtype == tl.bfloat16 else 5 + to_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10 + + # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them! 
+ non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits + non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits + dst_tensor = tl.where( + (tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) + == non_finite_mask_src, + (dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to( + dst_dtype, bitcast=True + ), + dst_tensor, + ) + else: + assert is_fp4 + dst_bias: tl.constexpr = 127 if dst_dtype == tl.bfloat16 else 15 + dst_0p5: tl.constexpr = 16128 if dst_dtype == tl.bfloat16 else 0x3800 + dst_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10 + # e2m1 + em0 = tensor & 0x07 + em1 = tensor & 0x70 + x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ( + (tensor & 0x08).to(tl.uint16) << 12 + ) + x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ( + (tensor & 0x80).to(tl.uint16) << 8 + ) + # Three cases: + # 1) x is normal and non-zero: Correct bias + x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0) + x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1) + # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0) + x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1) + # 3) x is zero, do nothing + dst_tensor = tl.interleave(x0, x1).to(dst_dtype, bitcast=True) + + # Reshape for proper broadcasting: the scale was stored with a 32‐sized “inner” grouping. + dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32]) + dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1]) + scale = scale.reshape(dst_scale.shape) + + out_tensor = dst_tensor * dst_scale + # Correct any NaNs encoded via the scale. 
def can_overflow_int32(tensor: torch.Tensor):
    """Return True if some element offset of `tensor` does not fit in int32.

    The largest reachable offset (in elements, not bytes) is the sum over all
    dimensions of ``(size - 1) * stride``.

    Args:
        tensor: Any tensor; only its shape and strides are inspected.

    Returns:
        bool: True when the largest element offset exceeds ``2**31 - 1``.
    """
    # An empty tensor has no addressable element. Without this guard a
    # zero-sized dimension contributes ``-1 * stride`` while a huge sibling
    # dimension can still push the sum past the limit, giving a false positive.
    if tensor.numel() == 0:
        return False
    max_int32 = (1 << 31) - 1
    max_offset = sum(
        (size - 1) * stride for size, stride in zip(tensor.shape, tensor.stride())
    )
    return max_offset > max_int32


def should_upcast_indices(*args):
    """Return True if any non-None tensor in `args` needs 64-bit indexing."""
    return any(tensor is not None and can_overflow_int32(tensor) for tensor in args)
def swizzle_scales(data):
    """Pre-shuffle a block of scales into the tiled layout built below.

    The last two dimensions of `data` are read as ``(SCALE_K, N)``. They are
    transposed, split into ``N -> (N // 32, 2, 16)`` and
    ``SCALE_K -> (SCALE_K // 8, 2, 4)``, permuted, and flattened.

    Args:
        data: Tensor of shape ``(E, SCALE_K, N)``. A 2-D ``(SCALE_K, N)``
            tensor is also accepted and yields a 2-D result.

    Returns:
        A transposed view of shape ``(E, SCALE_K * 32, N // 32)``
        (``(SCALE_K * 32, N // 32)`` for 2-D input).

    Raises:
        ValueError: if `data` has fewer than two dimensions, or if ``N`` is
            not a multiple of 32 or ``SCALE_K`` is not a multiple of 8.
    """
    NON_K_PRESHUFFLE_BLOCK_SIZE = 32
    K_GROUP = 8  # SCALE_K is split as (SCALE_K // 8, 2, 4)
    block_shape = data.shape
    if len(block_shape) < 2:
        raise ValueError(
            f"expected at least 2 dimensions, got shape {tuple(block_shape)}"
        )
    SCALE_K = block_shape[-2]
    N = block_shape[-1]
    if N % NON_K_PRESHUFFLE_BLOCK_SIZE != 0 or SCALE_K % K_GROUP != 0:
        raise ValueError(
            f"N ({N}) must be a multiple of {NON_K_PRESHUFFLE_BLOCK_SIZE} and "
            f"SCALE_K ({SCALE_K}) a multiple of {K_GROUP}"
        )
    # Inputs with 3 or more dims keep their first dimension as the expert
    # axis; a 2-D input has no expert axis at all.
    leading = () if data.ndim == 2 else (block_shape[0],)
    data = data.transpose(-1, -2)
    # reshape (not view): returns a view when strides allow, copies otherwise,
    # so arbitrarily-strided inputs no longer fail.
    data = data.reshape(
        -1, N // NON_K_PRESHUFFLE_BLOCK_SIZE, 2, 16, SCALE_K // K_GROUP, 2, 4, 1
    )
    data = data.permute(0, 1, 4, 6, 3, 5, 2, 7).contiguous()
    data = data.reshape(
        *leading,
        N // NON_K_PRESHUFFLE_BLOCK_SIZE,
        SCALE_K * NON_K_PRESHUFFLE_BLOCK_SIZE,
    )
    return data.transpose(-1, -2)
reduction_n=1, + out_dtype: bool = None, +): + """ + Grouped row reduction. + + Arguments + - x: Tensor[AnyFloat]; this wrapper passes x.stride(0), x.stride(1) and + x.stride(2) to the kernel, so `x` must be 3-D + - indx: Tensor[Int] of shape [num_groups, K], or None (then K = 1 and + num_groups = x.shape[-2]) + - out: destination tensor passed to the kernel and returned + + Description + NOTE(review): the text below describes in-place writes into `x`, but this + wrapper hands a separate `out` tensor to `_reduce_grouped` and returns it; + confirm against the kernel. + For each group g in [0, num_groups), this routine sums the K rows of `x` + specified by `indx[g, :]` and overwrites the row corresponding to the first + valid (non-negative) index with the per-group sum. Accumulation is performed + in float32 for numerical stability, and the result is written back in the + dtype of `x`. + + Behavior and edge cases + - Invalid (-1) entries are skipped during accumulation and do not generate + memory traffic. If a group has no valid entries, nothing is written for + that group. + - Reduction is performed tile-by-tile along the N dimension within a single + kernel launch (persistent along N) to minimize launch overhead. + + Performance notes + - Memory traffic per group is approximately (valid_rows_read + 1) * N * sizeof(x), + plus index reads. With no invalid entries, this becomes (K + 1) reads/writes + of length N per group. + + Returns + - `out`. When `indx` is None and `x.shape[0] == 1` the kernel is not launched + and `x.squeeze(0)` is returned instead. 
+ """ + if indx is None and x.shape[0] == 1: + return x.squeeze(0) + if indx is not None: + num_groups = indx.shape[0] + else: + num_groups = x.shape[-2] + K = 1 if indx is None else indx.shape[1] + out_dtype = x.dtype if out_dtype is None else out_dtype + assert x.shape[-1] % reduction_n == 0 + BLOCK_N = 512 + num_blocks = triton.cdiv(x.shape[-1], BLOCK_N) + + _reduce_grouped[(num_blocks, num_groups)]( + x, + x.stride(0), + x.stride(1), + x.stride(2), # + out, + out.stride(0), + out.stride(1), # + indx, # + x.shape[0], + x.shape[-1], # + apply_swiglu, + alpha, + limit, + reduction_n, + BLOCK_N=BLOCK_N, + EVEN_N=(x.shape[-1] % BLOCK_N == 0), + K=K, # + num_warps=2, # + ) + return out + + +# ----------------------------------------------------------------------------- +# Triton Implementation +# ----------------------------------------------------------------------------- + + +def moe_gemm_a8w4( + x, + w, + x_scales, + w_scales, + x_static_scale=None, + quant_static_scale=None, + bias=None, + routing_data: RoutingData | None = None, + gather_indx=None, + scatter_indx=None, + gammas=None, + swizzle_mx_scale=None, + out_dtype=torch.bfloat16, + apply_swiglu=False, + alpha=1.0, + limit=1.0, + unpadded_N=None, + unpadded_K=None, +): + """ + Y[:, :] = 0. 
+ for e in num_experts: + Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :]) + """ + assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp" + x_has_mx = x_scales is not None + if x_has_mx: + assert x.stride(-1) == 1, "'x' must be row-major when it has data-type mxfp" + if x_has_mx: + stride_x_mx_m = x_scales.stride(0) + stride_x_mx_k = x_scales.stride(1) + else: + stride_x_mx_m = 0 + stride_x_mx_k = 0 + # determine shapes + M = x.shape[-2] if gather_indx is None else gather_indx.shape[0] + K, N = x.shape[-1], w.shape[-1] + block_m = routing_data.block_m + if unpadded_N and block_m == 16: + N = unpadded_N + if unpadded_K and block_m == 16: + K = unpadded_K + # compute optimization flags + config = get_kernel_config(M, N, K, routing_data) + if apply_swiglu and config["split_k"] > 1: + apply_swiglu_matmul = False + reduction_n_matmul = 1 + apply_swiglu_reduction = True + reduction_n_reduction = 2 + elif apply_swiglu: + apply_swiglu_matmul = True + reduction_n_matmul = 2 + apply_swiglu_reduction = False + reduction_n_reduction = 1 + else: + apply_swiglu_matmul = False + reduction_n_matmul = 1 + apply_swiglu_reduction = False + reduction_n_reduction = 1 + # allocate output memory + y, y_final = allocate_output( + x, + w, + out_dtype, + reduction_n_matmul, + reduction_n_reduction, + routing_data, + gather_indx, + scatter_indx, + config["block_m"], + config["split_k"], + ) + stride_bias = None if bias is None else bias.stride(0) + # moe metadata + expt_data = routing_data.expt_data + expt_hist = None if expt_data is None else expt_data.hist + expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[-1] + expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw + expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map + # spmd grid + grid_m = routing_data.n_blocks(M, config["block_m"]) + grid_n = triton.cdiv(N, config["block_n"]) + grid = grid_m * grid_n * config["split_k"] + # launch 
kernel + _moe_gemm_a8w4[(grid,)]( + y, + y.stride(0), + y.stride(1), + y.stride(2), + x, + x.stride(0), + x.stride(1), + x_scales, + stride_x_mx_m, + stride_x_mx_k, + w, + w.stride(0), + w.stride(1), + w.stride(2), + w_scales, + w_scales.stride(0), + w_scales.stride(1), + w_scales.stride(2), + x_static_scale, + quant_static_scale, + bias, + stride_bias, + gammas, + N, + K, + gather_indx, + expt_hist, + expt_token_offs_raw, + expt_hist_sum, + expt_block_pid_map, + grid_m, + grid_n, + apply_swiglu_matmul, + alpha, + limit, + reduction_n_matmul, + routing_data.n_expts_act, + config["block_m"], + config["block_n"], + config["block_k"], + config["group_m"], + XCD_SWIZZLE=config["xcd_swizzle"], + SWIZZLE_MX_SCALE=swizzle_mx_scale, + SPLIT_K=config["split_k"], + EVEN_K=K % config["block_k"] == 0, + MASK_K_LIMIT=K % config["block_k"], + W_CACHE_MODIFIER=config["w_cache_modifier"], + num_warps=config["num_warps"], + num_stages=config["num_stages"], + UPCAST_INDICES=should_upcast_indices(x, w, y), + waves_per_eu=config["waves_per_eu"], + matrix_instr_nonkdim=config["matrix_instr_nonkdim"], + kpack=config["kpack"], + ) + # Build grouped reduction inputs in a uniform way + group_indx = ( + None + if scatter_indx is None + else scatter_indx.view(-1, routing_data.n_expts_act) + ) + y_final = reduce_grouped( + y, + group_indx, + y_final, + apply_swiglu_reduction, + alpha, + limit, + reduction_n_reduction, + out_dtype=out_dtype, + ) + return y_final + + +# ----------------------------------------------------------------------------- +# Reference Implementation +# ----------------------------------------------------------------------------- + + +def swiglu_torch(a, alpha, limit): + a_gelu = a[..., ::2] + if limit is not None: + a_gelu = a_gelu.clamp(max=limit) + a_linear = a[..., 1::2] + if limit is not None: + a_linear = a_linear.clamp(min=-limit, max=limit) + + out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu) + out = out_gelu * (a_linear + 1) + return out + + +def 
moe_gemm_torch( + x, + w, + bias, + routing_data: RoutingData = None, + gather_indx=None, + scatter_indx=None, + gammas=None, + apply_swiglu=False, + alpha=1.0, + limit=1.0, +): + assert x.dtype.itemsize > 1 + assert w.dtype.itemsize > 1 + if bias is not None and bias.ndim == 1: + bias = bias.view(1, *bias.shape) + if w.ndim == 2: + w = w.view(1, *w.shape) + n_expts_act = routing_data.n_expts_act + # memory offsets + if routing_data.n_expts_tot > 1: + sizes = routing_data.expt_hist + off = torch.zeros(sizes.shape[0] + 1, dtype=torch.int32) + off[1:] = torch.cumsum(sizes, 0) + offs = list(itertools.pairwise(off)) + else: + offs = [[0, x.shape[0]] for _ in range(w.shape[0])] + # compute + n_rows = x.shape[0] if gather_indx is None else gather_indx.shape[0] + n_cols = w.shape[-1] // 2 if apply_swiglu else w.shape[-1] + y = torch.zeros((n_rows, n_cols), device=x.device, dtype=x.dtype) + for i, (lo, hi) in enumerate(offs): + if gather_indx is None: + idx = torch.arange(lo, hi, device=x.device) + else: + idx = gather_indx[lo:hi] // n_expts_act + out = torch.matmul(x[idx, :].float(), w[i].float()) + if bias is not None: + out += bias[i, :] + if apply_swiglu: + out = swiglu_torch(out, alpha, limit) + if gammas is not None: + out *= gammas[lo:hi, None] + y[lo:hi, :] = out + if scatter_indx is None: + return y + # accumulate output from all experts + n_rows = y.shape[0] // n_expts_act + out = torch.zeros((n_rows, y.shape[-1]), dtype=torch.float32, device=x.device) + src_idx = scatter_indx.view(-1, n_expts_act) + for i in range(n_rows): + out[i, :] = y[src_idx[i], :].float().sum(0) + + return out diff --git a/aiter/ops/triton/moe_routing/bitmatrix.py b/aiter/ops/triton/moe_routing/bitmatrix.py new file mode 100644 index 0000000000..8a4c8e1bc4 --- /dev/null +++ b/aiter/ops/triton/moe_routing/bitmatrix.py @@ -0,0 +1,82 @@ +import torch +import triton +from typing import Type +from aiter.ops.triton._triton_kernels.moe_routing.bitmatrix import ( + _sum_bitmatrix_memset, + 
@dataclass
class Bitmatrix:
    """
    Represents a boolean matrix in a packed format where each element occupies
    a single bit of memory.

    `scratchpad` is either None or an all-zero array of size >= shape[-1]; we pass it along
    with the actual bitmatrix to avoid having to launch a separate memset
    kernel when we call Bitmatrix::sum().

    `scratchpad_partials` is handed to `sum_bitmatrix_rows` as the buffer for
    the partial (per-block) reductions.
    """

    # Only declared dataclass field; every attribute is actually assigned by
    # the hand-written __init__ below, which replaces the generated one.
    scratchpad: torch.Tensor = None

    def __init__(self, data, shape, scratchpad=None, scratchpad_partials=None):
        self.data = data
        self.shape = shape
        self.device = data.device
        self.scratchpad = scratchpad
        self.scratchpad_partials = scratchpad_partials

    def sum(self, partials_block_size):
        """Return ``(column_sums, partials)`` as computed by `sum_bitmatrix_rows`.

        Uses the first ``n_cols`` entries of the scratchpad as the output
        buffer, creating it with `clear_sums` when no scratchpad is set.
        """
        _, n_cols = self.shape
        dev = self.device
        if self.scratchpad is None:
            self.scratchpad = clear_sums(n_cols, dev)
        out_ret = self.scratchpad[:n_cols]
        # The scratchpad is consumed by this call. A later sum() does not
        # raise: it sees None above and allocates a fresh buffer via clear_sums.
        self.scratchpad = None
        return sum_bitmatrix_rows(self, out_ret, partials_block_size)
@dataclass
class ExptData:
    """Per-expert metadata that drives the ragged (grouped) matmul.

    Every field is either None or an int32 tensor; this is checked on
    construction.
    """

    # hist[i] is the number of tokens routed to expert i
    hist: torch.Tensor
    # token_offs_raw[i] is the offset of the first token routed
    # to expert i in an expert-sorted array
    token_offs_raw: torch.Tensor
    # token_offs_pad[i] is the offset of the first token routed
    # to expert i in an expert-sorted array, assuming histogram
    # rounded to the next multiple of `block_m`
    # NOTE(review): `compute_expt_data_torch` fills this with cumulative
    # *block* counts (cumsum of ceil(hist / block_m)), not token counts.
    # Confirm which unit the kernels expect.
    token_offs_pad: torch.Tensor
    # block_pid_map contains one value for each `pid` launched by
    # the matrix multiplication kernel launched with block_m:
    # - the value is -1 if the `pid` has no work to do
    # - otherwise, the value is two int16 (packed as an int32) that
    #   correspond respectively to (1) the expert assigned to
    #   the tokens processed by this pid; (2) the block assigned to the
    #   tokens processed by this pid (think `pid_m` in a regular matmul)
    #   (`compute_expt_data_torch` packs them as `(block << 16) + expert`)
    # see `test_routing.py` for a reference implementation and more details
    block_pid_map: torch.Tensor

    def __post_init__(self):
        # One check for all four fields, naming the offender on failure.
        for name in ("hist", "token_offs_raw", "token_offs_pad", "block_pid_map"):
            value = getattr(self, name)
            assert (
                value is None or value.dtype == torch.int32
            ), f"{name} must be an int32 tensor, got {value.dtype}"
torch.Tensor = field() + expt_hist: torch.Tensor = field() + n_expts_tot: int = field() + n_expts_act: int = field() + expt_data: ExptData = None + + def n_blocks(self, n_rows, block_m): + if n_rows <= self.n_expts_tot: + return n_rows + else: + return ( + triton.cdiv(max(n_rows - self.n_expts_tot + 1, 0), block_m) + + self.n_expts_tot + - 1 + ) + + +# -------------------------- +# sort tokens by expert +# -------------------------- + + +def sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOCK_M): + cdiv = triton.cdiv + + device = expt_scal.device + dtype = expt_scal.dtype + n_tokens, n_expts_act = expt_scal.shape + n_gates = n_tokens * n_expts_act + + hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M) + hist = hist[:n_expts_tot] + assert hist.dtype == torch.int32 + # scratchpad + combined_indx = torch.empty(n_gates * 2, dtype=torch.int32, device=device) + # output + topk_indx = combined_indx[:n_gates] + gate_indx = combined_indx[n_gates:] + gate_scal = torch.empty(n_gates, dtype=dtype, device=device) + + token_offs_raw, token_offs_pad, block_pid_map, blocks1a, BLOCK_A, block_m_log2 = ( + _compute_expt_data_internal(n_expts_tot, n_gates, block_m, device) + ) + + blocks1b = cdiv(n_tokens, HIST_BLOCK_M) + + indx_offs = partial_hist + + _combined_routing[(blocks1a + blocks1b,)]( + topk_indx, + gate_indx, + gate_scal, # outputs + expt_scal, + expt_indx, + indx_offs, + indx_offs.stride(0), + indx_offs.stride(1), # inputs + n_gates, # input shape + HIST_BLOCK_M, + n_tokens % HIST_BLOCK_M == 0, + n_expts_act, # constants + hist, + n_expts_tot, + token_offs_raw, + token_offs_pad, # + blocks1a, + block_pid_map, + block_pid_map.shape[0], # + block_m_log2, + BLOCK_A=BLOCK_A, + EQUAL_A=(hist.shape[0] == BLOCK_A), # optimization parameters + num_warps=1, + ) + + return ( + hist, + topk_indx, + gate_indx, + gate_scal, + token_offs_raw, + token_offs_pad, + block_pid_map, + ) + + +# -------------------------- +# expt_data +# 
# --------------------------


def log2_power_of_two(x):
    """Return ``log2(x)`` for a positive power-of-two integer ``x``.

    Raises:
        AssertionError: if ``x`` is not a positive power of two.
    """
    assert x > 0 and (x & (x - 1)) == 0, "x must be a power of two"
    return x.bit_length() - 1


def _max_n_tiles(n_expts_tot, n_gates, block_m):
    """Upper bound on the number of ``block_m``-row tiles over all experts."""
    if n_gates <= n_expts_tot:
        return n_gates
    # ceil_div(n_gates - n_expts_tot + 1, block_m) + n_expts_tot - 1,
    # written with ceil_div(x, y) == -(-x // y)
    return n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // block_m)


def _select_block_m(n_gates, n_expts_tot):
    """Pick the row-tile size: next power of two of tokens/expert, in [16, 128]."""
    tokens_per_expt = max(1, n_gates // n_expts_tot)
    return max(16, min(triton.next_power_of_2(tokens_per_expt), 128))


def _compute_expt_data_internal(n_expts_tot, n_gates, block_m, device):
    """Allocate the (uninitialised) int32 buffers used for expert metadata.

    Returns:
        ``(token_offs_raw, token_offs_pad, block_pid_map, blocks1, BLOCK,
        block_m_log2)``. The two offset tensors are length
        ``n_expts_tot + 1`` views into one shared allocation whose rows are
        padded to a multiple of ``BLOCK``; ``blocks1`` equals ``n_expts_tot``.
    """
    BLOCK = 128
    block_m_log2 = log2_power_of_two(block_m)
    n_tiles_max = _max_n_tiles(n_expts_tot, n_gates, block_m)

    # both offset arrays share one allocation, rows padded to BLOCK entries
    n_offs = n_expts_tot + 1
    padded_offs = triton.cdiv(n_offs, BLOCK) * BLOCK
    token_offs_combined = torch.empty(
        (2, padded_offs), dtype=torch.int32, device=device
    )
    token_offs_raw = token_offs_combined[0][:n_offs]
    token_offs_pad = token_offs_combined[1][:n_offs]

    block_pid_map = torch.empty((n_tiles_max,), dtype=torch.int32, device=device)

    blocks1 = n_expts_tot
    return token_offs_raw, token_offs_pad, block_pid_map, blocks1, BLOCK, block_m_log2


# --------------------------
# routing
# --------------------------


def routing(logits, n_expts_act, sm_first=False, expt_indx=None):
    """Route tokens to experts with the Triton top-k and sort kernels.

    Args:
        logits: 2-D tensor; its shape is unpacked as
            ``(num_tokens, n_expts_tot)``.
        n_expts_act: number of experts selected per token.
        sm_first: when True softmax is applied to ``logits`` before top-k,
            otherwise the top-k call is asked to apply it.
        expt_indx: optional pre-selected expert indices forwarded to ``topk``.

    Returns:
        ``(RoutingData, gather_indx, scatter_indx)``.
    """
    HIST_BLOCK_M = 32

    from .topk import topk

    if sm_first:
        logits = torch.softmax(logits, dim=-1)
    expt_scal, expt_indx, bitmatrix = topk(
        logits,
        n_expts_act,
        apply_softmax=not sm_first,
        y_indx=expt_indx,
        HIST_BLOCK_M=HIST_BLOCK_M,
    )

    num_tokens, n_expts_tot = logits.shape
    block_m = _select_block_m(num_tokens * n_expts_act, n_expts_tot)
    (
        hist,
        topk_indx,
        gate_indx,
        gate_scal,
        token_offs_raw,
        token_offs_pad,
        block_pid_map,
    ) = sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix, block_m, HIST_BLOCK_M)
    expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)

    # pack the matmul data structure: gather uses the expert-sorted order,
    # scatter uses its inverse
    routing_data = RoutingData(
        block_m, gate_scal, hist, n_expts_tot, n_expts_act, expt_data
    )
    return routing_data, topk_indx, gate_indx


# --------------------------
# torch reference
# --------------------------


def compute_expt_data_torch(hist, n_expts_tot, n_gates, block_m):
    """Pure-torch reference for the expert metadata.

    All intermediate tensors stay integer-typed. The earlier version built the
    offsets and ``block_pid_map`` in float32 and cast at the end, which rounds
    packed values ``(b << 16) + e`` once they reach 2**24.

    Args:
        hist: per-expert token counts (integer tensor).
        n_expts_tot: total number of experts.
        n_gates: total number of (token, expert) pairs.
        block_m: row-tile size used by the ragged matmul.

    Returns:
        ``ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)`` with
        int32 offsets; unused ``block_pid_map`` entries are ``-1``.
    """
    device = hist.device

    # exclusive prefix sum of tokens per expert
    raw_cumsum = torch.cumsum(hist, dim=0)
    token_offs_raw = torch.cat((raw_cumsum.new_zeros(1), raw_cumsum)).int()

    # maximum number of tiles for all values of `block_m` considered
    max_n_tiles = _max_n_tiles(n_expts_tot, n_gates, block_m)

    # exclusive prefix sum of matmul tiles per expert
    n_tiles = (hist + block_m - 1) // block_m
    tile_cumsum = torch.cumsum(n_tiles, dim=0)
    token_offs_pad = torch.cat((tile_cumsum.new_zeros(1), tile_cumsum)).int()

    # data required to drive the ragged batch matmul: each tile stores
    # (tile index within expert << 16) + expert id
    block_pid_map = torch.full(
        (max_n_tiles,), -1, dtype=torch.int32, device=device
    )
    tiles_per_expt = n_tiles.tolist()
    tile_starts = token_offs_pad.tolist()
    for e in range(n_expts_tot):
        count = int(tiles_per_expt[e])
        if count == 0:
            continue
        start = int(tile_starts[e])
        local_ids = torch.arange(count, dtype=torch.int32, device=device)
        block_pid_map[start : start + count] = local_ids * (1 << 16) + e
    return ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)


def routing_torch(logits, n_expts_act, sm_first=False, expt_indx=None):
    """Pure-torch reference implementation of ``routing``.

    Takes the same arguments as ``routing`` and returns
    ``(RoutingData, gather_indx, scatter_indx)`` with int32 index tensors.
    """
    has_user_provided_indx = expt_indx is not None
    n_gates_pad = logits.shape[0] * n_expts_act

    def topk(vals, k, expt_indx):
        # top-k experts per row, unless the caller supplied the indices
        if has_user_provided_indx:
            tk_indx = expt_indx
        else:
            tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k]
        tk_indx = tk_indx.long()
        tk_val = torch.take_along_dim(vals, tk_indx, dim=1)
        return tk_val, tk_indx.int()

    _, n_expts_tot = logits.shape
    if sm_first:
        logits = torch.softmax(logits, dim=-1)
    expt_scal, expt_indx = topk(logits, n_expts_act, expt_indx)
    if not sm_first:
        expt_scal = torch.softmax(expt_scal, dim=-1)
    # sort each token's selections by expert
    if not has_user_provided_indx:
        expt_indx, sort_indices = torch.sort(expt_indx, dim=1)
        expt_scal = torch.gather(expt_scal, 1, sort_indices)
    # flatten topk data
    expt_scal = expt_scal.reshape(-1)
    expt_indx = expt_indx.reshape(-1).to(torch.int32)
    # sort by expert_id so experts are contiguous for the matmul
    topk_indx = torch.argsort(expt_indx, stable=True)
    gate_indx = torch.argsort(topk_indx, stable=True)
    gate_scal = expt_scal[topk_indx]
    # histogram of tokens over experts
    hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1).int()
    # pack the matmul data structure
    gather_indx = topk_indx.int()
    scatter_indx = gate_indx.int()
    # compute expt_data
    block_m = _select_block_m(n_gates_pad, n_expts_tot)
    expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates_pad, block_m)
    return (
        RoutingData(block_m, gate_scal, hist, n_expts_tot, n_expts_act, expt_data),
        gather_indx,
        scatter_indx,
    )


# patch metadata (kept verbatim from the flattened diff):
#   diff --git a/aiter/ops/triton/moe_routing/topk.py b/aiter/ops/triton/moe_routing/topk.py
#   new file mode 100644
#   index 0000000000..a8bbac5ba8
#   --- /dev/null
#   +++ b/aiter/ops/triton/moe_routing/topk.py
#   @@ -0,0 +1,84 @@
import torch
import triton
from aiter.ops.triton._triton_kernels.moe_routing.topk import _topk
from aiter.ops.triton.moe_routing.bitmatrix import Bitmatrix


def topk(
    x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, HIST_BLOCK_M=32
):
    """Launch the ``_topk`` kernel and wrap its bitmatrix output.

    Args:
        x: 2-D input with fewer than 32768 columns.
        k: number of entries selected per row.
        apply_softmax: forwarded to the kernel as ``APPLY_SOFTMAX``.
        dim: must be 1.
        return_bitmatrix: must be truthy.
        y_indx: optional caller-provided indices; when given the kernel is
            told to use them (``use_provided_indx``) and the same tensor is
            returned.
        HIST_BLOCK_M: row block size used to size the partial-sum scratchpad.

    Returns:
        ``(y_vals, y_indx, bitmatrix)`` where ``bitmatrix`` is a ``Bitmatrix``
        carrying both scratchpads.
    """

    def cdiv(a, b):
        return (a + b - 1) // b

    # validate the rank before indexing into the shape
    assert len(x.shape) == 2
    x_shape = [x.shape[0], x.shape[1]]
    BLOCK_M = 32
    BLOCK_N = 128
    BLOCK_S = 128
    BLOCK_SP = 128
    assert x_shape[-1] < 32768
    assert dim == 1
    assert return_bitmatrix
    n_rows, n_cols = x_shape
    dev = x.device
    # scratchpad tensors
    # NOTE: these are not returned
    y_vals = torch.empty((n_rows, k), dtype=x.dtype, device=dev)
    use_provided_indx = y_indx is not None
    if not use_provided_indx:
        y_indx = torch.empty((n_rows, k), dtype=torch.int16, device=dev)
    # create bitmatrix in transposed memory layout:
    n_cols_pad = cdiv(n_cols, BLOCK_N) * BLOCK_N
    n_cols_words = n_cols_pad // 32
    bitmatrix = torch.empty(
        (n_cols_words, cdiv(n_rows, 32) * 32), dtype=torch.uint32, device=dev
    )
    bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows]
    s_blocks = cdiv(n_cols, BLOCK_S)
    s_cols = s_blocks * BLOCK_S
    scratchpad = torch.empty((s_cols,), dtype=torch.int32, device=dev)
    TILE_SIZE = 8
    BLOCK_MM = HIST_BLOCK_M * TILE_SIZE
    pids_x = cdiv(n_rows, BLOCK_MM)
    pids_y = cdiv(n_cols, 32)
    scratchpad_partials = torch.empty(
        (pids_y * 32, pids_x * TILE_SIZE), device=dev, dtype=torch.int32
    )
    scratchpad_partials = torch.transpose(scratchpad_partials, 0, 1)
    sp_size = torch.numel(scratchpad_partials)
    sp_blocks = cdiv(sp_size, BLOCK_SP)
    # one launch covers both the top-k rows and the scratchpad memset blocks
    pids = max(cdiv(n_rows, BLOCK_M), s_blocks + sp_blocks)
    _topk[(pids,)](
        x,
        x.stride(0),  # inputs
        y_vals,
        y_indx,
        y_vals.stride(0),
        use_provided_indx,  # output [topk]
        bitmatrix,
        bitmatrix.stride(0),
        bitmatrix.stride(1),  # output [bitmatrix]
        n_rows,
        n_cols,  # shapes
        scratchpad,
        BLOCK_S,
        s_blocks,  # thing to memset to zero
        scratchpad_partials,
        BLOCK_SP,
        sp_blocks,
        sp_size,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,  # tunable parameter
        APPLY_SOFTMAX=apply_softmax,
        N_EXPTS_PAD=n_cols_pad,
        N_EXPTS_ACT=k,  # constants
        num_warps=8,
    )
    bitmatrix = Bitmatrix(
        bitmatrix,
        shape=[n_rows, n_cols_words * 32],
        scratchpad=scratchpad,
        scratchpad_partials=scratchpad_partials,
    )
    return y_vals, y_indx, bitmatrix


# patch metadata (kept verbatim from the flattened diff; continues on the next line):
#   diff --git a/aiter/ops/triton/quant_moe.py b/aiter/ops/triton/quant_moe.py
#   new file mode 100644
#   index 0000000000..f4dd62676b
#   ---
# patch metadata (kept verbatim from the flattened diff; continues the header above):
#   /dev/null
#   +++ b/aiter/ops/triton/quant_moe.py
#   @@ -0,0 +1,159 @@
from enum import Enum
import triton
import torch
from aiter.ops.triton._triton_kernels.quant_moe import (
    _downcast_to_static_fp8,
    _downcast_to_mxfp,
    _upcast_from_mxfp,
)


def downcast_to_static_fp8(x: torch.Tensor, scale: torch.Tensor):
    """Quantise a 2-D tensor to ``float8_e4m3fn`` with one static scale.

    Args:
        x: 2-D input tensor.
        scale: scale passed through to the ``_downcast_to_static_fp8`` kernel.

    Returns:
        A new ``float8_e4m3fn`` tensor with the shape of ``x``, allocated on
        ``x.device`` (previously it was always placed on the default CUDA
        device, which breaks for inputs living on another device).
    """
    M, N = x.shape
    y = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x.device)
    if M == 0 or N == 0:
        # nothing to convert; avoid launching an empty grid
        return y

    BLOCK_M = min(triton.next_power_of_2(M), 128)
    BLOCK_N = 32 if M <= 4096 else 64
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    _downcast_to_static_fp8[grid](
        x,
        x.stride(0),
        x.stride(1),
        y,
        y.stride(0),
        y.stride(1),
        scale,
        M,
        N,
        BLOCK_M,
        BLOCK_N,
        num_warps=8,
    )

    return y


class DequantScaleRoundingMode(Enum):
    # the integer value is what gets passed to the _downcast_to_mxfp kernel
    ROUND_UP = 0
    ROUND_DOWN = 1


def downcast_to_mxfp(
    src_tensor: torch.Tensor,
    out_quant_type: torch.dtype,
    axis: int,
    DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP,
):
    """
    Convert the src weights to mx format. The src weight is quantized along the axis dimension.

    If out_quant_type is torch.uint8, we output mxfp4 where two e2m1 values are packed into a single byte.
    Note that this means the k_dim of the tensor will be half of the logical k_dim.

    If out_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8 where the float8s are stored
    in their respective formats.

    Returns ``(out_quant_tensor, out_scale)``; the scale tensor is uint8 with
    ``cdiv(L, 32)`` entries along ``axis`` (``L`` being the size of ``axis``).
    """
    ndim = src_tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    axis = axis if axis >= 0 else axis + ndim
    last = ndim - 1
    # downcast: move the quantised axis last so the kernel sees it as columns
    src_tensor = src_tensor.transpose(axis, last)
    is_fp4 = out_quant_type == torch.uint8
    is_fp8 = out_quant_type in (torch.float8_e4m3fn, torch.float8_e5m2)
    assert is_fp4 or is_fp8
    divisor = 2 if is_fp4 else 1
    L = src_tensor.shape[-1]
    if is_fp4:
        assert L % 2 == 0, f"axis dim must be divisible by 2 for e2m1. Got {L}"
    lead_shape = src_tensor.shape[:-1]
    out_shape = lead_shape + (L // divisor,)
    out_scale_shape = lead_shape + (triton.cdiv(L, 32),)

    out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type)
    out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8)

    # flatten all leading dims; the kernel works on 2-D views
    kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1])
    kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1])
    kernel_scale = out_scale.view(-1, out_scale.shape[-1])

    BLOCK_OUT_DIM = 128
    BLOCK_QUANT_DIM = 32
    grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
    grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)

    _downcast_to_mxfp[(grid_out, grid_quant)](
        kernel_quant_tensor,
        *kernel_quant_tensor.stride(),
        kernel_scale,
        *kernel_scale.stride(),
        kernel_src_tensor,
        *kernel_src_tensor.stride(),
        *kernel_src_tensor.shape,
        BLOCK_OUT_DIM,
        BLOCK_QUANT_DIM,
        DEQUANT_SCALE_ROUNDING_MODE.value,
        num_warps=8,
    )

    # restore the caller's axis order
    out_quant_tensor = out_quant_tensor.transpose(axis, last)
    out_scale = out_scale.transpose(axis, last)
    return out_quant_tensor, out_scale


def upcast_from_mxfp(
    tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype, axis: int
):
    """
    Upcasts an mxfp (packed) weight tensor back to float16 or bfloat16.

    The function assumes that the tensors were quantized along the given axis.
    It permutes the tensor so that the quantized axis is last, reshapes to 2D,
    launches the Triton upcast kernel, and then unpermutes back to the original order.
    """
    ndim = tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    axis = axis if axis >= 0 else axis + ndim
    assert tensor.ndim == scale.ndim, (
        f"Weight and scale must have the same number of dimensions. "
        f"Got {tensor.ndim=} and {scale.ndim=}"
    )
    # dtype checks
    assert tensor.dtype in {
        torch.uint8,
        torch.float8_e5m2,
        torch.float8_e4m3fn,
    }, f"Invalid tensor dtype {tensor.dtype=}"
    assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}"
    assert dtype in (torch.float16, torch.bfloat16), f"Invalid output dtype {dtype=}"
    # upcast: uint8 storage holds two fp4 values per byte
    values_per_elem = 2 if tensor.dtype == torch.uint8 else 1
    logical_quant_dim = tensor.shape[axis] * values_per_elem
    tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous()
    scale = scale.transpose(axis, scale.ndim - 1).contiguous()
    out = torch.empty(
        (*tensor.shape[:-1], logical_quant_dim), dtype=dtype, device=tensor.device
    )
    reshaped_out = out.view(-1, out.shape[-1])
    reshaped_tensor = tensor.view(-1, tensor.shape[-1])
    reshaped_scale = scale.view(-1, scale.shape[-1])
    BLOCK_OUT_DIM = 128
    BLOCK_QUANT_DIM = 32
    blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM)
    blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM)
    _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](
        reshaped_out,
        *reshaped_out.stride(),
        reshaped_scale,
        *reshaped_scale.stride(),
        reshaped_tensor,
        *reshaped_tensor.stride(),
        *reshaped_out.shape,
        BLOCK_OUT_DIM,
        BLOCK_QUANT_DIM,
        num_warps=8,
    )
    out = out.transpose(axis, scale.ndim - 1).contiguous()
    return out


# patch metadata and partial hunk (kept verbatim from the flattened diff; the
# enclosing pid_grid definition is not fully visible here):
#   diff --git a/aiter/ops/triton/utils/_triton/pid_preprocessing.py b/aiter/ops/triton/utils/_triton/pid_preprocessing.py
#   index e38caf7754..e3c2b47bbc 100644
#   --- a/aiter/ops/triton/utils/_triton/pid_preprocessing.py
#   +++ b/aiter/ops/triton/utils/_triton/pid_preprocessing.py
#   @@ -73,6 +73,7 @@ def pid_grid(pid: int, num_pid_m: int, num_pid_n: int, GROUP_SIZE_M: tl.constexp
#        group_id = pid // num_pid_in_group
#        first_pid_m = group_id * GROUP_SIZE_M
#        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
#   +    tl.assume(group_size_m >= 0)
#        pid_m = first_pid_m + (pid % group_size_m)
#        pid_n = (pid % num_pid_in_group) // group_size_m
#
#   diff --git
# patch metadata (kept verbatim from the flattened diff; continues the header above):
#   a/op_tests/op_benchmarks/triton/bench_moe_gemm_a8w4.py b/op_tests/op_benchmarks/triton/bench_moe_gemm_a8w4.py
#   new file mode 100644
#   index 0000000000..36464e27e9
#   --- /dev/null
#   +++ b/op_tests/op_benchmarks/triton/bench_moe_gemm_a8w4.py
#   @@ -0,0 +1,320 @@
# adapted from triton_kernels package
# original code https://github.com/triton-lang/triton/blob/main/python/triton_kernels/bench/bench_mlp.py

from itertools import chain
from pathlib import Path
from copy import deepcopy
import csv
import triton.profiler as proton
import torch
import argparse
from aiter.ops.triton.moe_routing.routing import routing
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
from aiter.ops.triton.moe_op_gemm_a8w4 import (
    moe_gemm_a8w4,
    swizzle_scales,
)
from aiter.ops.triton.utils._triton.arch_info import get_arch
import tempfile
from aiter.ops.triton.quant_moe import downcast_to_static_fp8, downcast_to_mxfp
import inspect


def parse_profile(profile_path, useful_op_regex, reps):
    """
    construct a PerfRecord from a (proton) profile path and a regex for useful operations

    Returns a dict with keys ``total_time_ns``, ``kernel_time_ns``, ``flops``,
    ``bytes`` and ``reps``.
    """
    from triton.profiler import viewer

    gf, _, _, _ = viewer.read(profile_path)
    # aggregate "useful" flops + bytes
    useful = gf.filter(
        f"MATCH ('*', c) WHERE c.'name' =~ '{useful_op_regex}' AND c IS LEAF"
    ).dataframe
    n_bytes = int(useful["bytes"].sum())
    flop_columns = [c for c in ["flops8", "flops16"] if c in useful.columns]
    flops = int(sum(useful[flop_columns].sum()))
    # take all ops (incl. "not useful" ones) when computing total time
    allops = gf.filter("MATCH ('*', c) WHERE c IS LEAF").dataframe
    return {
        "total_time_ns": allops["time (ns)"].sum(),
        "kernel_time_ns": useful["time (ns)"].sum(),
        "flops": flops,
        "bytes": n_bytes,
        "reps": reps,
    }


def compute_roofline(
    *args, bench_fn, intensity_proxy_name, intensity_proxy_values, out_path, **kwargs
):
    """Sweep ``bench_fn`` over ``intensity_proxy_values`` and print the results.

    Each value is inserted into ``args`` at the position of the parameter
    named ``intensity_proxy_name`` in ``bench_fn``'s signature. ``out_path``
    is only printed; nothing is written to it here. Returns ``None``.

    Raises:
        TypeError: if ``intensity_proxy_name`` is not a string.
        ValueError: if ``bench_fn`` has no parameter with that name.
    """
    # validate input args
    if not isinstance(intensity_proxy_name, str):
        raise TypeError(
            "intensity_proxy must be a string naming a parameter in target_fn"
        )
    # determine position of intensity_proxy in target_fn signature
    sig = inspect.signature(bench_fn)
    if intensity_proxy_name not in sig.parameters:
        raise ValueError(
            f"Parameter '{intensity_proxy_name}' not found in {bench_fn.__name__} signature"
        )
    pos_index = list(sig.parameters).index(intensity_proxy_name)

    # wrapper to inject intensity proxy into target_fn and call it
    def inject_proxy_and_call(val, args, kwargs):
        args_list = list(args)
        args_list.insert(pos_index, val)
        return bench_fn(*args_list, **kwargs)

    # collect performance data
    print("=========================================")
    print(f"{out_path }...")
    print("=========================================")
    for val in intensity_proxy_values:
        perf = inject_proxy_and_call(val, args, kwargs)
        kernel_ns = perf["kernel_time_ns"]
        tflops = perf["flops"] / kernel_ns * 1e-3
        tbps = perf["bytes"] / kernel_ns * 1e-3
        total_latency = perf["total_time_ns"] / 1e3 / perf["reps"]
        kernel_latency = kernel_ns / 1e3 / perf["reps"]
        print(
            f"{intensity_proxy_name}: {val:5d} | Total latency (us): {total_latency:.2f} | Kernel latency (us): {kernel_latency:.2f} | TFLOPS: {tflops:#.4g} | TBPS: {tbps:.2f}"
        )


def check_and_swizzle_scales(scale, N, K):
    """Swizzle ``scale`` when ``N % 32 == 0`` and ``K % 256 == 0``.

    Returns ``(scale, "CDNA4_SCALE")`` if swizzled, else ``(scale, None)``.
    """
    if N % 32 == 0 and K % (32 * 8) == 0:
        return swizzle_scales(scale), "CDNA4_SCALE"
    return scale, None


def quantize(x, dtype):
    """Convert ``x`` to the benchmark dtype named by ``dtype``.

    ``dtype`` is one of ``"bf16"``, ``"fp8"``, ``"mx8"`` or ``"mx4"``.
    Always returns a 2-tuple ``(tensor, scale)``; ``scale`` is ``None`` for
    ``"bf16"``.
    """
    if dtype == "bf16":
        x = x.to(torch.bfloat16).transpose(-1, -2).contiguous().transpose(-1, -2)
        return x, None
    if dtype == "fp8":
        scale = x.abs().max().item() / 448.0
        fp8e4_dtype = (
            torch.float8_e4m3fn if get_arch() != "gfx942" else torch.float8_e4m3fnuz
        )
        return x.to(fp8e4_dtype), scale
    if dtype == "mx8":
        fp8e4_dtype = (
            torch.float8_e4m3fn if get_arch() != "gfx942" else torch.float8_e4m3fnuz
        )
        return downcast_to_mxfp(x, fp8e4_dtype, axis=1)
    assert dtype == "mx4", f"{dtype=}"
    return downcast_to_mxfp(x.to(torch.bfloat16), torch.uint8, axis=1)


def bench_mlp(
    batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, op_regex
):
    """Profile ``reps`` iterations of gate GEMM + routing + two MoE GEMMs.

    ``x_dtype`` is ``"fp8"`` (static-scale fp8 activations) or ``"mx8"``
    (mxfp8 activations). Returns the dict produced by ``parse_profile``.
    """
    rank = 0
    dev = f"cuda:{rank}"

    assert dim2 % TP == 0, f"{dim2=}, {TP=}, dim2 must be divisible by TP"

    # -- init data --
    # weights
    wg = torch.randn((dim1, n_expts_tot), device=dev)
    w1 = torch.randn((n_expts_tot, dim1, dim2 // TP), device=dev)
    w2 = torch.randn((n_expts_tot, dim2 // TP // 2, dim1), device=dev)
    # biases
    bg = torch.randn((n_expts_tot,), device=dev)
    b1 = torch.randn((n_expts_tot, dim2 // TP), device=dev)
    b2 = torch.randn((n_expts_tot, dim1), device=dev)

    # -- numerics --
    wg, _ = quantize(wg, "bf16")
    w1, w1_scale = quantize(w1, w_dtype)
    w2, w2_scale = quantize(w2, w_dtype)
    w1_scale, swizzle_mx_scale1 = check_and_swizzle_scales(w1_scale, dim2 // TP, dim1)
    w2_scale, swizzle_mx_scale2 = check_and_swizzle_scales(
        w2_scale, dim1, dim2 // TP // 2
    )

    # -- benchmark --
    x_dtype_str = x_dtype
    # reject unknown activation dtypes before the profiler is started
    assert x_dtype_str in ("fp8", "mx8")
    x_dtype = torch.float8_e4m3fn
    # special treatment of fp8_e4m3 on AMD CDNA3 because it uses fp8_e4m3fnuz
    if x_dtype == torch.float8_e4m3fn and get_arch() == "gfx942":
        x_dtype = torch.float8_e4m3fnuz

    reps = 100
    x = torch.randn((batch, dim1), dtype=torch.bfloat16, device=dev)
    xg = x
    static_scale = None
    if x_dtype_str == "fp8":
        static_scale = torch.tensor(1e-4, device=dev)
    # run layer
    # mkdtemp creates the directory atomically; tempfile.mktemp is deprecated
    # because the returned name can be claimed by another process
    fpath = Path(tempfile.mkdtemp()) / "profile"
    proton.start(str(fpath), hook="triton")
    for _ in range(reps):
        logits = gemm_a16w16(xg, wg.T, bg)
        rdata, gather_indx, scatter_indx = routing(logits, n_expts_act)
        if x_dtype_str == "fp8":
            x = downcast_to_static_fp8(x, static_scale)
            x = moe_gemm_a8w4(
                x,
                w1,
                None,
                w1_scale,
                static_scale,
                static_scale,
                b1,
                rdata,
                gather_indx=gather_indx,
                swizzle_mx_scale=swizzle_mx_scale1,
                out_dtype=x_dtype,
                apply_swiglu=True,
            )
            x = moe_gemm_a8w4(
                x,
                w2,
                None,
                w2_scale,
                static_scale,
                None,
                b2,
                rdata,
                scatter_indx=scatter_indx,
                swizzle_mx_scale=swizzle_mx_scale2,
            )
        else:
            # quantize() returns (tensor, scale); unpacking three values here
            # used to raise ValueError on the first iteration
            # NOTE(review): this path hard-codes "CDNA4_SCALE" instead of using
            # swizzle_mx_scale1/2 — confirm that is intended for unswizzled shapes
            x, x_scale = quantize(x, x_dtype_str)
            x = moe_gemm_a8w4(
                x,
                w1,
                x_scale,
                w1_scale,
                None,
                None,
                b1,
                rdata,
                gather_indx=gather_indx,
                swizzle_mx_scale="CDNA4_SCALE",
                apply_swiglu=True,
            )
            x, x_scale = quantize(x, x_dtype_str)
            x = moe_gemm_a8w4(
                x,
                w2,
                x_scale,
                w2_scale,
                None,
                None,
                b2,
                rdata,
                scatter_indx=scatter_indx,
                swizzle_mx_scale="CDNA4_SCALE",
            )
    proton.finalize()
    return parse_profile(
        fpath.with_suffix(".hatchet"), useful_op_regex=op_regex, reps=reps
    )


def roofline_mlp(
    batch_sizes,
    dim1,
    dim2,
    n_expts_tot,
    n_expts_act,
    x_dtype,
    w_dtype,
    TP,
    op_regex,
    name="",
):
    """Create the log directory and sweep ``bench_mlp`` over ``batch_sizes``."""
    out_path = Path(f"logs/{name}/{x_dtype}x-{w_dtype}w-TP{TP}/")
    out_path.mkdir(parents=True, exist_ok=True)
    # compute_roofline only prints and returns None, so nothing is bound here
    compute_roofline(
        dim1,
        dim2,
        n_expts_tot,
        n_expts_act,
        x_dtype,
        w_dtype,
        TP,
        op_regex,  # fixed args
        bench_fn=bench_mlp,  # function to benchmark
        intensity_proxy_name="batch",  # intensity proxy name
        intensity_proxy_values=batch_sizes,  # intensity proxy values to sweep
        out_path=out_path.with_suffix(".csv"),
    )  # output path


def parse_args():
    """Parse the benchmark's command-line arguments."""
    parser = argparse.ArgumentParser(prog="Benchmark MoE")
    # --shape and --experts are unpacked unconditionally by the entry point,
    # so they are required: omitting one gives a usage error, not a TypeError
    parser.add_argument(
        "--shape",
        type=int,
        nargs="+",
        metavar=("DIM"),
        required=True,
        help="Input feature dimensions of MoE layers. Must be two integers.",
    )
    parser.add_argument(
        "--experts",
        type=int,
        nargs="+",
        metavar=("DIM"),
        required=True,
        help="Number of total and active experts in [total experts, active experts] order.",
    )
    parser.add_argument(
        "--op-regex",
        type=str,
        default=".*moe_gemm.*",
        help="Regex to find perf for specific operation by its kernel name.",
    )
    parser.add_argument(
        "--act-dtype",
        type=str,
        default="fp8",
        help="Activation dtype, fp8 or mx8.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    dim1, dim2 = args.shape
    total_experts, active_experts = args.experts
    # (start, stop, step) triples expanded with range()
    batch_ranges_moe = [
        (1, 2, 1),
        (2, 5, 2),
        (8, 18, 8),
        (32, 65, 32),
        (128, 257, 128),
        (1024, 1200, 200),
        (4096, 8200, 4096),
    ]
    batch_sizes_moe = list(chain.from_iterable(range(*r) for r in batch_ranges_moe))
    quantized_dtypes = [args.act_dtype, "mx4"]

    roofline_mlp(
        batch_sizes_moe,
        dim1,
        dim2,
        total_experts,
        active_experts,
        quantized_dtypes[0],
        quantized_dtypes[1],
        TP=1,
        op_regex=args.op_regex,
        name="gpt-oss-x2",
    )

# patch metadata (kept verbatim from the flattened diff):
#   diff --git a/op_tests/triton_tests/test_moe_gemm_a8w4.py b/op_tests/triton_tests/test_moe_gemm_a8w4.py
#   new file mode 100644
#   index 0000000000..b8e3fa71a3
#   --- /dev/null
#   +++ b/op_tests/triton_tests/test_moe_gemm_a8w4.py
#   @@ -0,0 +1,325 @@
# adapted from triton_kernels package
# original code https://github.com/triton-lang/triton/blob/main/python/triton_kernels/tests/test_matmul.py

from dataclasses import dataclass, fields
import itertools
import pytest
import torch
from typing import Union
import triton

# routing utilities
from aiter.ops.triton.moe_routing.routing import routing

# matmul utilities
from aiter.ops.triton.moe_op_gemm_a8w4 import (
    moe_gemm_a8w4,
    moe_gemm_torch,
    swizzle_scales,
)

#
# numerics utilities
from aiter.ops.triton.quant_moe import (
    downcast_to_static_fp8,
    downcast_to_mxfp,
    upcast_from_mxfp,
)

# target-specific utilities
from aiter.ops.triton.utils._triton.arch_info import get_arch

# ---------------
# initialize data
# ---------------


def alloc_rand(shape, device, dtype):
    """Allocate random test data.

    For 1-byte dtypes the result is a bfloat16 tensor of values ``2 ** -e``
    with ``e`` drawn from ``[4, 8)`` (it is NOT cast to ``dtype``); otherwise
    it is standard-normal data of the requested dtype.
    """
    if dtype.itemsize != 1:
        return torch.randn(shape, device=device, dtype=dtype)
    exponents = torch.randint(4, 8, shape, device=device, dtype=torch.bfloat16)
    return 2 ** -exponents


def alloc_rand_like(x):
    """Call ``alloc_rand`` with the shape, device and dtype of ``x``."""
    return alloc_rand(x.shape, x.device, x.dtype)


def init_routing_data(
    m, n_expts_tot, n_expts_act, do_gather, do_scatter, device="cuda"
):
    """Build routing metadata from random float16 logits.

    Returns ``(m, routing_data, gather_idx, scatter_idx)``; an index is
    ``None`` when the corresponding ``do_*`` flag is false, and
    ``routing_data.gate_scal`` is cleared.
    """
    logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=device)
    routing_data, gather_idx, scatter_idx = routing(logits, n_expts_act)
    routing_data.gate_scal = None
    if not do_gather:
        gather_idx = None
    if not do_scatter:
        scatter_idx = None
    # TODO: re-enable
    # if do_gather and do_scatter and n_expts_act == 1 and n_expt_shards == 1:
    #     scatter_idx = mask_indx(scatter_idx, n_expts_act)
    return m, routing_data, gather_idx, scatter_idx


def init_compute_data(
    m,
    n,
    k,
    gindx,
    sindx,
    n_expts_tot,
    n_expts_act,
    act_dtype,
    weight_dtype,
    has_y_gammas,
    device="cuda",
):
    """Create ``(x, w, bias, gamma)`` test tensors with a fixed seed.

    ``x`` has ``m * n_expts_act`` rows when no gather index is given and
    ``m`` rows otherwise; ``gamma`` is ``None`` unless ``has_y_gammas``.
    """
    torch.manual_seed(0)
    rows_per_token = n_expts_act if gindx is None else 1
    x = alloc_rand((m * rows_per_token, k), device=device, dtype=act_dtype)
    w = alloc_rand((n_expts_tot, k, n), device=device, dtype=weight_dtype)
    bias = alloc_rand((n_expts_tot, n), device=device, dtype=torch.float32)
    gamma = None
    if has_y_gammas:
        gamma = 2 ** torch.randint(
            -5, 0, (m * n_expts_act,), device=device, dtype=torch.float32
        )
    return x, w, bias, gamma


def dtype_str_to_torch(dtype_str: str) -> torch.dtype:
    """Map a dtype name to a torch dtype; ``"float4_e2m1"`` maps to uint8 storage."""
    if dtype_str == "float4_e2m1":
        return torch.uint8
    return getattr(torch, dtype_str)


def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
    """
    Compare reference values against obtained values.

    For 1-byte result dtypes the reference is first cast to that dtype, and
    if both already share it an exact comparison is done instead. Otherwise
    both tensors are compared in float32 via a maximum and an RMS relative
    error (defaults ``maxtol=2e-2``, ``rmstol=4e-3``). Empty references pass.
    """
    if tri.dtype.itemsize == 1:
        ref_as_type = ref.to(tri.dtype)
        if ref.dtype == tri.dtype:
            assert torch.all(ref_as_type == tri)
            return
        ref = ref_as_type

    if ref.numel() == 0:
        return

    if maxtol is None:
        maxtol = 2e-2
    if rmstol is None:
        rmstol = 4e-3

    # cast to float32:
    ref = ref.to(torch.float32).detach()
    tri = tri.to(torch.float32).detach()
    assert (
        ref.shape == tri.shape
    ), f"Tensors must have same size {ref.shape=} {tri.shape=}"

    # deal with infinite elements:
    inf_mask_ref = torch.isinf(ref)
    inf_mask_tri = torch.isinf(tri)
    assert torch.equal(
        inf_mask_ref, inf_mask_tri
    ), "Tensor must have same infinite elements"
    refn = torch.where(inf_mask_ref, 0, ref)
    trin = torch.where(inf_mask_tri, 0, tri)

    # normalise so that RMS calculation doesn't overflow:
    eps = 1.0e-30
    multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)
    refn *= multiplier
    trin *= multiplier

    ref_rms = torch.sqrt(torch.square(refn).mean()) + eps

    rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))
    max_err = torch.max(rel_err).item()
    rms_err = torch.sqrt(torch.square(rel_err).mean()).item()

    if verbose:
        print(
            "%s maximum relative error = %s (threshold = %s)"
            % (description, max_err, maxtol)
        )
        print(
            "%s RMS relative error = %s (threshold = %s)"
            % (description, rms_err, rmstol)
        )

    if max_err > maxtol:
        # report at most the first 1000 offending coordinates
        bad_idxs = torch.nonzero(rel_err > maxtol)
        num_nonzero = bad_idxs.size(0)
        bad_idxs = bad_idxs[:1000]
        print(
            "%d / %d mismatched elements (shape = %s) at coords %s"
            % (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist())
        )

        bad_idxs = bad_idxs.unbind(-1)
        print("ref values: ", ref[tuple(bad_idxs)].cpu())
        print("tri values: ", tri[tuple(bad_idxs)].cpu())

    assert max_err <= maxtol
    assert rms_err <= rmstol


# ---------------
# unit tests
#
--------------- + + +@dataclass +class Case: + m: int + n: int + k: int + act_dtype_str: str + n_expts_tot: int = 1 + n_expts_act: int = 1 + hbm_swizzling: bool = False + + +@pytest.mark.parametrize( + ", ".join(f.name for f in fields(Case)), + [ + tuple(getattr(case, f.name) for f in fields(Case)) + for case in [ + Case(32, 6144, 3072, "float8_e4m3fn", 128, 4, hbm_swizzling=True), + Case(8192, 3072, 3072, "float8_e4m3fn", 128, 4, hbm_swizzling=True), + Case(4, 1024, 3072, "float8_e4m3fn", 128, 4, hbm_swizzling=True), + Case(1024, 3072, 512, "float8_e4m3fn", 128, 4, hbm_swizzling=True), + Case(4096, 3072, 3072, "float8_e4m3fn", 128, 4), + Case(16, 1024, 1024, "mxfloat8_e4m3fn", 128, 4, hbm_swizzling=True), + Case(4096, 1024, 1024, "mxfloat8_e4m3fn", 128, 4), + Case(16, 256, 256, "mxfloat8_e4m3fn", 128, 4, hbm_swizzling=True), + Case(4096, 256, 256, "mxfloat8_e4m3fn", 128, 4), + Case(1000, 704, 800, "mxfloat8_e4m3fn", 8, 2), + Case(300, 400, 800, "mxfloat8_e4m3fn", 8, 4), + ] + ], +) +@pytest.mark.parametrize( + "do_gather, do_scatter", + [ + (False, False), + (True, False), + (False, True), + (True, True), + ], +) +@pytest.mark.parametrize("has_y_gammas", [False, True]) +@pytest.mark.parametrize("apply_swiglu", [False, True]) +@pytest.mark.parametrize("fused_quant", [False, True]) +def test_op( + m, + n, + k, + do_gather, + do_scatter, + has_y_gammas, + apply_swiglu, + fused_quant, + n_expts_tot, + n_expts_act, + act_dtype_str, + hbm_swizzling, + device="cuda", +): + + if get_arch() != "gfx950": + pytest.skip("float8 x mx only supported on CDNA4") + + if "float8_e4m3fnuz" in act_dtype_str and get_arch() != "gfx942": + pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform") + + if hbm_swizzling: + if get_arch() != "gfx950": + pytest.skip( + "Scale preshuffling on AMD GPU has not been emulated on non-CDNA4 arch yet." 
+ ) + if n % 32 != 0 or k % (32 * 8) != 0: + pytest.skip( + f"Shape {m}x{n}x{k} is not supported for scale swizzling on AMD GPU" + ) + + torch.manual_seed(0) + + weight_dtype_str = "mxfloat4_e2m1" + weight_mxfp = weight_dtype_str.startswith("mx") + if weight_mxfp: + weight_dtype_str = weight_dtype_str[2:] + act_mxfp8 = act_dtype_str.startswith("mx") + if act_mxfp8: + act_dtype_str = act_dtype_str[2:] + + weight_dtype = dtype_str_to_torch(weight_dtype_str) + act_dtype = dtype_str_to_torch(act_dtype_str) + m, rdata, gindx, sindx = init_routing_data( + m, n_expts_tot, n_expts_act, do_gather, do_scatter, device=device + ) + x_tri, w_tri, bias_tri, gammas = init_compute_data( + m, + n, + k, + gindx, + sindx, + n_expts_tot, + n_expts_act, + torch.bfloat16 if act_mxfp8 else act_dtype, + torch.bfloat16, + has_y_gammas, + device=device, + ) + x_ref, w_ref, bias_ref = x_tri.clone(), w_tri.clone(), bias_tri.clone() + + # downcast to mxfp + w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=1) + w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=1) + if hbm_swizzling: + swizzle_mx_scale = "CDNA4_SCALE" + w_scale_tri = swizzle_scales(w_scale_tri) + else: + swizzle_mx_scale = None + + if act_mxfp8: + x_tri, x_mx_scales_tri = downcast_to_mxfp(x_tri, act_dtype, axis=-1) + x_ref = upcast_from_mxfp(x_tri, x_mx_scales_tri, torch.bfloat16, axis=-1) + x_static_scale = None + out_dtype = torch.bfloat16 + maxtol = None + rmstol = None + else: + x_mx_scales_tri = None + x_static_scale = x_tri.abs().max().float() / 448.0 + x_tri = downcast_to_static_fp8(x_tri, x_static_scale) + out_dtype = torch.float8_e4m3fn + maxtol = 4e-1 + rmstol = 4e-2 + + ref_y = moe_gemm_torch( + x_ref, w_ref, bias_ref, rdata, gindx, sindx, gammas, apply_swiglu + ) + if not act_mxfp8 and fused_quant: + quant_static_scale = ref_y.abs().max().float() / 448.0 + else: + quant_static_scale = None + tri_y = moe_gemm_a8w4( + x_tri, + w_tri, + x_mx_scales_tri, + w_scale_tri, + x_static_scale, + 
quant_static_scale, + bias_tri, + rdata, + gindx, + sindx, + gammas, + swizzle_mx_scale, + out_dtype, + apply_swiglu, + ) + if not act_mxfp8 and fused_quant: + tri_y = (tri_y.float() * quant_static_scale).to(ref_y.dtype) + assert_close(ref_y, tri_y, maxtol=maxtol, rmstol=rmstol) diff --git a/op_tests/triton_tests/test_moe_routing.py b/op_tests/triton_tests/test_moe_routing.py new file mode 100644 index 0000000000..85477f0dda --- /dev/null +++ b/op_tests/triton_tests/test_moe_routing.py @@ -0,0 +1,168 @@ +import pytest +import torch +from aiter.ops.triton.moe_routing.routing import routing, routing_torch +from aiter.ops.triton.utils._triton.arch_info import get_arch + + +def assert_equal(ref, tri): + if isinstance(ref, torch.Tensor): + assert torch.all(ref == tri) + else: + assert ref == tri + + +def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True): + if tri.dtype.itemsize == 1: + ref_as_type = ref.to(tri.dtype) + if ref.dtype == tri.dtype: + assert torch.all(ref_as_type == tri) + return + ref = ref_as_type + + if maxtol is None: + maxtol = 2e-2 + if rmstol is None: + rmstol = 4e-3 + """ + Compare reference values against obtained values. 
+ """ + + # cast to float32: + ref = ref.to(torch.float32).detach() + tri = tri.to(torch.float32).detach() + assert ( + ref.shape == tri.shape + ), f"Tensors must have same size {ref.shape=} {tri.shape=}" + + # deal with infinite elements: + inf_mask_ref = torch.isinf(ref) + inf_mask_tri = torch.isinf(tri) + assert torch.equal( + inf_mask_ref, inf_mask_tri + ), "Tensor must have same infinite elements" + refn = torch.where(inf_mask_ref, 0, ref) + trin = torch.where(inf_mask_tri, 0, tri) + + # normalise so that RMS calculation doesn't overflow: + eps = 1.0e-30 + multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps) + refn *= multiplier + trin *= multiplier + + ref_rms = torch.sqrt(torch.square(refn).mean()) + eps + + rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn)) + max_err = torch.max(rel_err).item() + rms_err = torch.sqrt(torch.square(rel_err).mean()).item() + + if verbose: + print( + "%s maximum relative error = %s (threshold = %s)" + % (description, max_err, maxtol) + ) + print( + "%s RMS relative error = %s (threshold = %s)" + % (description, rms_err, rmstol) + ) + + if max_err > maxtol: + bad_idxs = torch.nonzero(rel_err > maxtol) + num_nonzero = bad_idxs.size(0) + bad_idxs = bad_idxs[:1000] + print( + "%d / %d mismatched elements (shape = %s) at coords %s" + % (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist()) + ) + + bad_idxs = bad_idxs.unbind(-1) + print("ref values: ", ref[tuple(bad_idxs)].cpu()) + print("tri values: ", tri[tuple(bad_idxs)].cpu()) + + assert max_err <= maxtol + assert rms_err <= rmstol + + +def init_data(n_tokens, n_expts_tot, dtype=torch.float16, device="cuda"): + logits = torch.randn((n_tokens, n_expts_tot), dtype=dtype, device=device) + return logits + + +n_tokens = [4, 7, 8, 64, 255, 256, 371, 911, 1023, 1024, 4096, 8192] + + +@pytest.mark.parametrize("n_tokens", n_tokens) +@pytest.mark.parametrize("n_expts_tot, n_expts_act", [(128, 4), (128, 32), (1500, 8)]) 
@pytest.mark.parametrize("n_tokens", n_tokens)
@pytest.mark.parametrize("n_expts_tot, n_expts_act", [(128, 4), (128, 32), (1500, 8)])
@pytest.mark.parametrize("use_expt_indx", [False, True])
@pytest.mark.parametrize("sm_first", [True, False])
def test_op(n_tokens, n_expts_tot, n_expts_act, sm_first, use_expt_indx):
    """Check that ``routing`` agrees with the ``routing_torch`` reference.

    Random float32 logits (and, when ``use_expt_indx`` is set, random sorted
    expert indices per token) are fed to both implementations; the gate
    scales are compared within tolerance and the histograms, offsets, block
    map and gather/scatter indices are compared exactly. The test is skipped
    unless ``get_arch()`` reports ``gfx950``.
    """
    if get_arch() != "gfx950":
        pytest.skip("MOE stack not fully implemented on non-CDNA4 arch yet.")

    device = "cuda"
    torch.manual_seed(2)
    n_gates_raw = n_tokens * n_expts_act
    tri_logits = init_data(
        n_tokens, n_expts_tot, device=device, dtype=torch.float32
    ).detach()
    tri_logits[n_tokens:, :] = float("inf")  # should not be used
    ref_logits = tri_logits.clone().detach()

    if use_expt_indx:
        # One random permutation per token, truncated to the active experts.
        tri_expt_indx = torch.stack(
            [
                torch.randperm(n_expts_tot, device=device, dtype=torch.int64)[
                    :n_expts_act
                ]
                for _ in range(n_tokens)
            ]
        )
        tri_expt_indx, _ = torch.sort(tri_expt_indx, dim=1)
        tri_expt_indx[n_tokens:] = -99999  # should not be used
        ref_expt_indx = tri_expt_indx[:n_tokens]
    else:
        tri_expt_indx = ref_expt_indx = None

    ref_routing_data, ref_gather, ref_scatter = routing_torch(
        ref_logits, n_expts_act, sm_first, ref_expt_indx
    )
    tri_routing_data, tri_gather, tri_scatter = routing(
        tri_logits, n_expts_act, sm_first, tri_expt_indx
    )

    def _assert_indx_equal(ref, tri):
        # The leading len(ref) entries must match; the rest must all be -1.
        assert_equal(ref, tri[: len(ref)])
        assert torch.all(tri[len(ref) :] == -1)

    assert_close(
        ref_routing_data.gate_scal, tri_routing_data.gate_scal[:n_gates_raw], 2e-2, 4e-3
    )
    assert_equal(ref_routing_data.expt_hist, tri_routing_data.expt_hist)

    ref_expt_data = ref_routing_data.expt_data
    tri_expt_data = tri_routing_data.expt_data
    assert_equal(ref_expt_data.hist, tri_expt_data.hist)
    assert_equal(ref_expt_data.token_offs_raw, tri_expt_data.token_offs_raw)
    assert_equal(ref_expt_data.token_offs_pad, tri_expt_data.token_offs_pad)
    assert_equal(ref_expt_data.block_pid_map, tri_expt_data.block_pid_map)

    # Compare reference against the implementation under test; the earlier
    # form compared the reference with itself and could never fail.
    assert ref_routing_data.n_expts_tot == tri_routing_data.n_expts_tot
    assert ref_routing_data.n_expts_act == tri_routing_data.n_expts_act

    _assert_indx_equal(ref_gather, tri_gather)
    _assert_indx_equal(ref_scatter, tri_scatter)
def bench_routing():
    """Profile 100 calls to ``routing`` with proton and try to show the result.

    Uses 8192 tokens with 128 experts, 4 of them active. After profiling, the
    ``proton-viewer`` command is launched on ``routing.hatchet`` on a
    best-effort basis: if the viewer cannot be started a note is printed
    instead of raising.
    """
    import subprocess

    import triton.profiler as proton

    n_tokens = 8192
    n_expts_tot, n_expts_act = 128, 4
    tri_logits = init_data(n_tokens, n_expts_tot)
    proton.start("routing")
    proton.activate()
    try:
        for _ in range(100):
            routing(tri_logits, n_expts_act)
    finally:
        # Close the profiling session even when a routing call fails.
        proton.finalize()

    # os.system never raises for a missing command, so the former
    # try/except around it was dead code; subprocess.run reports a missing
    # executable as OSError, which is handled explicitly here.
    try:
        subprocess.run(
            ["proton-viewer", "-m", "time/ms", "routing.hatchet"], check=False
        )
    except OSError as err:
        print(f"could not launch proton-viewer: {err}")


if __name__ == "__main__":
    bench_routing()