diff --git a/aiter/ops/triton/_gluon_kernels/moe/moe_op_gemm_a4w4.py b/aiter/ops/triton/_gluon_kernels/moe/moe_op_gemm_a4w4.py
new file mode 100644
index 0000000000..3f4a1d4302
--- /dev/null
+++ b/aiter/ops/triton/_gluon_kernels/moe/moe_op_gemm_a4w4.py
@@ -0,0 +1,631 @@
+import torch
+from triton.experimental import gluon
+import triton.experimental.gluon.language as gl
+from aiter.ops.triton._triton_kernels.moe.activations import _swiglu
+
+
+def matmul_launch_metadata(grid, kernel, args):
+    ret = dict()
+    M, N, K = None, args["N"], args["K"]
+    Y, X, W = args["Y"], args["X"], args["W"]
+    hist = args["ExptHist"]
+    if hist is not None:
+        n_rows = int(hist.float().mean())
+        n_tokens = float(hist.sum())
+        n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum()
+    else:
+        n_tokens = None
+        n_w_bytes = W.numel() * W.element_size()
+
+    def repr(s, x):
+        return f"{s}={x}" if x is not None else f"E_{len(hist)}({s})={n_rows}"
+
+    nbits = X.dtype.itemsize * 8
+    ret["name"] = f"{kernel.name} [{repr('M', M)}, {repr('N', N)}, {repr('K', K)}]"
+    gindx = args.get("GatherIndx", None)
+    # sindx = args.get("WriteBackIndx", None)
+    if gindx is not None:
+        ret["name"] += "_layer1"
+    else:
+        ret["name"] += "_layer2"
+    if args["B"] is not None:
+        ret["name"] += "_bias"
+    if args["APPLY_SWIGLU"]:
+        ret["name"] += "_swiglu"
+
+    fM = n_tokens
+    fK = K if K is not None else n_tokens
+    ret[f"flops{nbits}"] = 2.0 * fM * N * fK
+
+    n_x_bytes = X.numel() * X.element_size()
+    n_y_bytes = Y.numel() * Y.element_size()
+    if hist is not None:
+        assert n_tokens is not None
+        n_expts_act = args["N_EXPTS_ACT"]
+
+        if gindx is not None:
+            # recreate inverse GatherIndx.
+            dst = torch.full_like(gindx, -1)
+            idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32)
+            mask = gindx != -1
+            dst[gindx[mask]] = idx[mask]
+            n_read_rows = (dst.view((-1, n_expts_act)) != -1).any(dim=1).sum()
+        else:
+            n_read_rows = n_tokens
+        n_x_bytes = n_read_rows * X.shape[-1] * X.element_size()
+        n_y_bytes = n_tokens * Y.shape[-1] * Y.element_size()
+    ret["bytes"] = int(n_x_bytes + n_y_bytes + n_w_bytes)
+
+    return ret
+
+
+@gluon.jit
+def pid_grid(pid: int, num_pid_m: int, num_pid_n: int, GROUP_SIZE_M: gl.constexpr = 1):
+    """
+    Maps a 1D pid to 2D grid coordinates (pid_m, pid_n).
+
+    Args:
+    - pid: 1D program id
+    - num_pid_m: grid size along M
+    - num_pid_n: grid size along N
+    - GROUP_SIZE_M: number of M rows grouped per super-block (default 1)
+    """
+    if GROUP_SIZE_M == 1:
+        pid_m = pid // num_pid_n
+        pid_n = pid % num_pid_n
+    else:
+        num_pid_in_group = GROUP_SIZE_M * num_pid_n
+        group_id = pid // num_pid_in_group
+        first_pid_m = group_id * GROUP_SIZE_M
+        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+        gl.assume(group_size_m >= 0)
+        pid_m = first_pid_m + (pid % group_size_m)
+        pid_n = (pid % num_pid_in_group) // group_size_m
+    return pid_m, pid_n
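+
+
+# Illustrative walk-through of the grouped mapping above (values are
+# hypothetical): with num_pid_m=4, num_pid_n=3, GROUP_SIZE_M=2, pids 0..5 map
+# to (m, n) = (0,0), (1,0), (0,1), (1,1), (0,2), (1,2) before moving on to
+# rows 2..3, so consecutive pids reuse the same W tile. The kernel below calls
+# pid_grid with GROUP_SIZE_M=1, which degenerates to plain row-major order.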
+
+
+# @gluon.jit(launch_metadata=matmul_launch_metadata, loop_carried_load_percent=0)
+@gluon.jit(launch_metadata=matmul_launch_metadata)
+def _moe_gemm_a4w4_gfx1250(
+    Y,
+    stride_y_k,
+    stride_y_m,
+    stride_y_n,
+    X,
+    stride_x_m,
+    stride_x_k,
+    XMxScale,
+    stride_x_mx_m,
+    stride_x_mx_k,
+    W,
+    stride_w_e,
+    stride_w_k,
+    stride_w_n,
+    WMxScale,
+    stride_w_mx_e,
+    stride_w_mx_k,
+    stride_w_mx_n,
+    # bias
+    B,
+    stride_b_e,
+    Gammas,
+    # shapes
+    num_tokens,
+    N,
+    K,
+    # expt data
+    GatherIndx,
+    ExptHist,
+    ExptOffs,
+    ExptOffsSum,
+    ExptData,
+    # true grid size
+    grid_m,
+    grid_n,
+    # fused activation function
+    APPLY_SWIGLU: gl.constexpr,
+    alpha,
+    limit,
+    ACTIVATION_REDUCTION_N: gl.constexpr,
+    ADD_RESIDUAL: gl.constexpr,
+    # MoE config
+    N_EXPTS_ACT: gl.constexpr,
+    # optimization config
+    BLOCK_M: gl.constexpr,
+    BLOCK_N: gl.constexpr,
+    BLOCK_K: gl.constexpr,
+    XCD_SWIZZLE: gl.constexpr,
+    SPLIT_K: gl.constexpr,
+    SWIZZLE_MX_SCALE: gl.constexpr,  # "GFX1250_SCALE" | None
+    NUM_BUFFERS: gl.constexpr,
+    UPCAST_INDICES: gl.constexpr,
+    # layouts
+    WMMA_LAYOUT: gl.constexpr,
+    WMMA_LAYOUT_PACKED: gl.constexpr,
+    # triton configs
+    NUM_WARPS: gl.constexpr,
+):
+    gl.assume(stride_y_k >= 0)
+    gl.assume(stride_y_m >= 0)
+    gl.assume(stride_y_n >= 0)
+    gl.assume(stride_x_m >= 0)
+    gl.assume(stride_x_k >= 0)
+    gl.assume(stride_w_e >= 0)
+    gl.assume(stride_w_k >= 0)
+    gl.assume(stride_w_n >= 0)
+    gl.assume(stride_x_mx_m >= 0)
+    gl.assume(stride_x_mx_k >= 0)
+    gl.assume(stride_w_mx_e >= 0)
+    gl.assume(stride_w_mx_k >= 0)
+    gl.assume(stride_w_mx_n >= 0)
+    if B is not None:
+        gl.assume(stride_b_e >= 0)
+    gl.assume(grid_m >= 0)
+    gl.assume(grid_n >= 0)
+
+    MX_PACK_DIVISOR: gl.constexpr = 32
+    gl.static_assert(
+        BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR"
+    )
+
+    NUM_LOADS_IN_BATCH: gl.constexpr = 2
+    gl.static_assert(NUM_BUFFERS >= 3, "NUM_BUFFERS must be at least 3")
+
+    w_type: gl.constexpr = W.dtype.element_ty
+    gl.static_assert(w_type == gl.uint8, "mx_weight_ptr must be uint8")
+    gl.static_assert(
+        WMxScale.dtype.element_ty == gl.uint8, "mx_scale_ptr must be uint8"
+    )
+    x_type: gl.constexpr = X.dtype.element_ty
+    gl.static_assert(x_type == gl.uint8, "mx_act_ptr must be uint8")
+    gl.static_assert(
+        XMxScale.dtype.element_ty == gl.uint8, "mx_scale_ptr must be uint8"
+    )
+
+    OUT_BLOCK_N: gl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
+    yN = N // ACTIVATION_REDUCTION_N
+
+    # get program id
+    pid = gl.program_id(0)
+    if ExptOffsSum is not None and XCD_SWIZZLE > 1:
+        # Determine how much padding there is on the expert data. This allows us to
+        # know the true grid size and avoid processing padding tiles.
+        padding_m = grid_m - gl.load(ExptOffsSum)
+    else:
+        padding_m: gl.constexpr = 0
+
+    index_type: gl.constexpr = gl.int64 if UPCAST_INDICES else gl.int32
+
+    # get unpadded grid size
+    unpadded_m = grid_m - padding_m
+    gl.assume(unpadded_m >= 0)
+    total_actual_tiles = unpadded_m * grid_n
+    if padding_m > 0 and pid >= total_actual_tiles:
+        return
+
+    # swizzle program ids
+    pid_mn = pid % (unpadded_m * grid_n)
+    pid_m, pid_n = pid_grid(pid_mn, unpadded_m, grid_n, 1)
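+    # ExptData packs two 16-bit fields per M-block: the low half is the expert
+    # id, the high half the block index within that expert's token range
+    # (e.g. 0x00030001 = block 3 of expert 1); -1 marks a padding block.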
+
+    # unpack expert data
+    expt_data = gl.load(ExptData + pid_m)
+    if expt_data == -1:
+        return
+    expt_id = expt_data & 0x0000FFFF
+    block_id = expt_data >> 16
+    M = gl.load(ExptHist + expt_id)
+    start_m = gl.load(ExptOffs + expt_id)
+    expt_id, block_id = expt_id.to(index_type), block_id.to(index_type)
+    start_m = start_m.to(index_type)
+    pid_n = pid_n.to(index_type)
+
+    # get the packed block sizes: both A and B tensors are mxfp4, with
+    # 2 MXFP4 elements packed into 1 int8 in the K dimension
+    X_M_DIVISOR: gl.constexpr = 1
+    X_K_DIVISOR: gl.constexpr = 2  # 2 MXFP4 elements packed into 1 byte
+    W_K_DIVISOR: gl.constexpr = 2  # 2 MXFP4 elements packed into 1 byte
+    W_N_DIVISOR: gl.constexpr = 1
+    PACKED_BLOCK_M_X: gl.constexpr = BLOCK_M // X_M_DIVISOR
+    PACKED_BLOCK_K_X: gl.constexpr = BLOCK_K // X_K_DIVISOR
+    PACKED_BLOCK_K_W: gl.constexpr = BLOCK_K // W_K_DIVISOR
+    PACKED_BLOCK_N_W: gl.constexpr = BLOCK_N // W_N_DIVISOR
+    MX_SCALE_BLOCK_K: gl.constexpr = (
+        BLOCK_K // MX_PACK_DIVISOR
+    )  # 32 elements share 1 scale element
+
+    # wmma layouts
+    DOT_LAYOUT_X: gl.constexpr = gl.DotOperandLayout(
+        operand_index=0, parent=WMMA_LAYOUT_PACKED, k_width=16
+    )
+    DOT_LAYOUT_W: gl.constexpr = gl.DotOperandLayout(
+        operand_index=1, parent=WMMA_LAYOUT_PACKED, k_width=16
+    )
+    DOT_LAYOUT_X_SCALES: gl.constexpr = gl.amd.gfx1250.get_wmma_scale_layout(
+        DOT_LAYOUT_X, [PACKED_BLOCK_M_X, MX_SCALE_BLOCK_K]
+    )
+    DOT_LAYOUT_W_SCALES: gl.constexpr = gl.amd.gfx1250.get_wmma_scale_layout(
+        DOT_LAYOUT_W, [PACKED_BLOCK_N_W, MX_SCALE_BLOCK_K]
+    )
+
+    # A pointers
+    offs_x_m = PACKED_BLOCK_M_X * block_id
+    if GatherIndx is None:
+        X += start_m * stride_x_m
+    else:
+        IDX_LAYOUT: gl.constexpr = gl.SliceLayout(
+            0, gl.BlockedLayout([1, 8], [32, 1], [1, NUM_WARPS], [0, 1])
+        )
+        offs_x_m = PACKED_BLOCK_M_X * block_id + gl.arange(
+            0, PACKED_BLOCK_M_X, layout=IDX_LAYOUT
+        )
+        GatherIndx += start_m
+        offs_x_m = gl.amd.gfx1250.buffer_load(GatherIndx, offs_x_m) // N_EXPTS_ACT
+
+    # B pointers
+    offs_w_n = pid_n * PACKED_BLOCK_N_W
+    W += expt_id * stride_w_e
+
+    # A scale pointers
+    if GatherIndx is None:
+        XMxScale += start_m * stride_x_mx_m
+
+    # B scale pointers
+    WMxScale += expt_id * stride_w_mx_e
+    if SWIZZLE_MX_SCALE == "GFX1250_SCALE":
+        SCALE_KWIDTH: gl.constexpr = 4 if MX_SCALE_BLOCK_K >= 4 else MX_SCALE_BLOCK_K
+        PRESHUFFLE_FACTOR: gl.constexpr = 128
+        PACKED_MX_BLOCK: gl.constexpr = MX_SCALE_BLOCK_K * PRESHUFFLE_FACTOR
+        SCALE_BLOCK_N: gl.constexpr = BLOCK_N // PRESHUFFLE_FACTOR
+    else:
+        PRESHUFFLE_FACTOR: gl.constexpr = 1
+        PACKED_MX_BLOCK: gl.constexpr = MX_SCALE_BLOCK_K
+        SCALE_BLOCK_N: gl.constexpr = BLOCK_N
+    offs_w_n_scale = pid_n * SCALE_BLOCK_N
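+    # With "GFX1250_SCALE", the host-side swizzle (swizzle_scales_gfx1250)
+    # flattens the scales of PRESHUFFLE_FACTOR = 128 consecutive N rows into
+    # the K dimension, so each descriptor row below is one contiguous 128-row
+    # scale chunk that a single TDM load can fetch.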
+
+    # shared layouts
+    if PACKED_BLOCK_K_X <= 256:
+        SHARED_LAYOUT_X: gl.constexpr = gl.PaddedSharedLayout.with_identity_for(
+            [[256, 16]], [PACKED_BLOCK_M_X, PACKED_BLOCK_K_X], [1, 0]
+        )
+    else:
+        SHARED_LAYOUT_X: gl.constexpr = gl.PaddedSharedLayout.with_identity_for(
+            [[PACKED_BLOCK_K_X, 16]], [PACKED_BLOCK_M_X, PACKED_BLOCK_K_X], [1, 0]
+        )
+    if PACKED_BLOCK_K_W <= 256:
+        SHARED_LAYOUT_W: gl.constexpr = gl.PaddedSharedLayout.with_identity_for(
+            [[256, 16]], [PACKED_BLOCK_N_W, PACKED_BLOCK_K_W], [1, 0]
+        )
+    else:
+        SHARED_LAYOUT_W: gl.constexpr = gl.PaddedSharedLayout.with_identity_for(
+            [[PACKED_BLOCK_K_W, 16]], [PACKED_BLOCK_N_W, PACKED_BLOCK_K_W], [1, 0]
+        )
+    SHARED_LAYOUT_X_SCALES: gl.constexpr = gl.PaddedSharedLayout.with_identity_for(
+        [[256, 16]], [PACKED_BLOCK_M_X, MX_SCALE_BLOCK_K], [1, 0]
+    )
+    if SWIZZLE_MX_SCALE == "GFX1250_SCALE":
+        SHARED_LAYOUT_W_SCALES: gl.constexpr = gl.SwizzledSharedLayout(
+            vec=1, per_phase=1, max_phase=1, order=[1, 0]
+        )
+    else:
+        SHARED_LAYOUT_W_SCALES: gl.constexpr = gl.PaddedSharedLayout.with_identity_for(
+            [[256, 16]], [SCALE_BLOCK_N, PACKED_MX_BLOCK], [1, 0]
+        )
+
+    if GatherIndx is None:
+        x_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor(
+            base=X,
+            shape=(M, K // X_K_DIVISOR),
+            strides=(stride_x_m, stride_x_k),
+            block_shape=(PACKED_BLOCK_M_X, PACKED_BLOCK_K_X),
+            layout=SHARED_LAYOUT_X,
+        )
+    else:
+        x_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor(
+            base=X,
+            shape=(num_tokens, K // X_K_DIVISOR),
+            strides=(stride_x_m, stride_x_k),
+            block_shape=(PACKED_BLOCK_M_X, PACKED_BLOCK_K_X),
+            layout=SHARED_LAYOUT_X,
+        )
+    w_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor(
+        base=W,
+        shape=(N, K // W_K_DIVISOR),
+        strides=(stride_w_n, stride_w_k),
+        block_shape=(PACKED_BLOCK_N_W, PACKED_BLOCK_K_W),
+        layout=SHARED_LAYOUT_W,
+    )
+    if GatherIndx is None:
+        x_scales_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor(
+            base=XMxScale,
+            shape=(M, gl.cdiv(K, MX_PACK_DIVISOR)),
+            strides=(stride_x_mx_m, stride_x_mx_k),
+            block_shape=(PACKED_BLOCK_M_X, MX_SCALE_BLOCK_K),
+            layout=SHARED_LAYOUT_X_SCALES,
+        )
+    else:
+        x_scales_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor(
+            base=XMxScale,
+            shape=(num_tokens, gl.cdiv(K, MX_PACK_DIVISOR)),
+            strides=(stride_x_mx_m, stride_x_mx_k),
+            block_shape=(PACKED_BLOCK_M_X, MX_SCALE_BLOCK_K),
+            layout=SHARED_LAYOUT_X_SCALES,
+        )
+    w_scales_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor(
+        base=WMxScale,
+        shape=(N // PRESHUFFLE_FACTOR, gl.cdiv(K, MX_PACK_DIVISOR) * PRESHUFFLE_FACTOR),
+        strides=(stride_w_mx_n, stride_w_mx_k),
+        block_shape=(SCALE_BLOCK_N, PACKED_MX_BLOCK),
+        layout=SHARED_LAYOUT_W_SCALES,
+    )
+
+    x_buffer = gl.allocate_shared_memory(
+        x_desc.dtype, shape=[NUM_BUFFERS] + x_desc.block_shape, layout=x_desc.layout
+    )
+    w_buffer = gl.allocate_shared_memory(
+        w_desc.dtype, shape=[NUM_BUFFERS] + w_desc.block_shape, layout=w_desc.layout
+    )
+    x_scales_buffer = gl.allocate_shared_memory(
+        x_scales_desc.dtype,
+        shape=[NUM_BUFFERS] + x_scales_desc.block_shape,
+        layout=x_scales_desc.layout,
+    )
+    w_scales_buffer = gl.allocate_shared_memory(
+        w_scales_desc.dtype,
+        shape=[NUM_BUFFERS] + w_scales_desc.block_shape,
+        layout=w_scales_desc.layout,
+    )
+
+    load_idx = 0
+    wmma_idx = 0
+
+    # prologue: fill NUM_BUFFERS-1 LDS slots via TDM
+    for _ in gl.static_range(NUM_BUFFERS - 1):
+        if GatherIndx is None:
+            gl.amd.gfx1250.tdm.async_load(
+                x_desc,
+                [offs_x_m.to(index_type), load_idx * PACKED_BLOCK_K_X],
+                x_buffer.index(load_idx % NUM_BUFFERS),
+            )
+        else:
+            gl.amd.gfx1250.tdm.async_gather(
+                x_desc,
+                offs_x_m.to(index_type),
+                load_idx * PACKED_BLOCK_K_X,
+                x_buffer.index(load_idx % NUM_BUFFERS),
+            )
+        gl.amd.gfx1250.tdm.async_load(
+            w_desc,
+            [offs_w_n.to(index_type), load_idx * PACKED_BLOCK_K_W],
+            w_buffer.index(load_idx % NUM_BUFFERS),
+        )
+        if GatherIndx is None:
+            gl.amd.gfx1250.tdm.async_load(
+                x_scales_desc,
+                [offs_x_m.to(index_type), load_idx * MX_SCALE_BLOCK_K],
+                x_scales_buffer.index(load_idx % NUM_BUFFERS),
+            )
+        else:
+            gl.amd.gfx1250.tdm.async_gather(
+                x_scales_desc,
+                offs_x_m.to(index_type),
+                load_idx * MX_SCALE_BLOCK_K,
+                x_scales_buffer.index(load_idx % NUM_BUFFERS),
+            )
+        gl.amd.gfx1250.tdm.async_load(
+            w_scales_desc,
+            [offs_w_n_scale.to(index_type), load_idx * PACKED_MX_BLOCK],
+            w_scales_buffer.index(load_idx % NUM_BUFFERS),
+        )
+        load_idx += 1
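+    # Note on the wait logic (roughly speaking): each async_wait below is
+    # issued with (NUM_BUFFERS - 2) * NUM_LOADS_IN_BATCH, i.e. it lets that
+    # many load batches stay in flight and only blocks until the oldest
+    # buffer is fully resident in LDS and safe to read.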
+
+    # preload tile 0 from LDS into registers
+    gl.amd.gfx1250.tdm.async_wait((NUM_BUFFERS - 2) * NUM_LOADS_IN_BATCH)
+    cur_x = x_buffer.index(wmma_idx % NUM_BUFFERS).load(layout=DOT_LAYOUT_X)
+    cur_w = (
+        w_buffer.index(wmma_idx % NUM_BUFFERS).permute((1, 0)).load(layout=DOT_LAYOUT_W)
+    )
+    cur_x_scales = x_scales_buffer.index(wmma_idx % NUM_BUFFERS).load(
+        layout=DOT_LAYOUT_X_SCALES
+    )
+    if SWIZZLE_MX_SCALE == "GFX1250_SCALE":
+        cur_w_scales = (
+            w_scales_buffer.index(wmma_idx % NUM_BUFFERS)
+            .reshape(
+                (
+                    SCALE_BLOCK_N,
+                    MX_SCALE_BLOCK_K // SCALE_KWIDTH,
+                    PRESHUFFLE_FACTOR // 4,
+                    4,
+                    SCALE_KWIDTH,
+                )
+            )
+            .permute((0, 3, 2, 1, 4))
+            .reshape((BLOCK_N, MX_SCALE_BLOCK_K))
+        ).load(layout=DOT_LAYOUT_W_SCALES)
+    else:
+        cur_w_scales = w_scales_buffer.index(wmma_idx % NUM_BUFFERS).load(
+            layout=DOT_LAYOUT_W_SCALES
+        )
+    wmma_idx += 1
+
+    # main loop: perform wmma and fill LDS with next tile
+    num_k_iter = gl.cdiv(K, BLOCK_K)
+    acc = gl.zeros((BLOCK_M, BLOCK_N), dtype=gl.float32, layout=WMMA_LAYOUT)
+    for _ in range(num_k_iter - (NUM_BUFFERS - 1)):
+        # issue wmma
+        acc = gl.amd.gfx1250.wmma_scaled(
+            cur_x, cur_x_scales, "e2m1", cur_w, cur_w_scales, "e2m1", acc
+        )
+
+        # fill next tile to LDS
+        if GatherIndx is None:
+            gl.amd.gfx1250.tdm.async_load(
+                x_desc,
+                [offs_x_m.to(index_type), load_idx * PACKED_BLOCK_K_X],
+                x_buffer.index(load_idx % NUM_BUFFERS),
+            )
+        else:
+            gl.amd.gfx1250.tdm.async_gather(
+                x_desc,
+                offs_x_m.to(index_type),
+                load_idx * PACKED_BLOCK_K_X,
+                x_buffer.index(load_idx % NUM_BUFFERS),
+            )
+        gl.amd.gfx1250.tdm.async_load(
+            w_desc,
+            [offs_w_n.to(index_type), load_idx * PACKED_BLOCK_K_W],
+            w_buffer.index(load_idx % NUM_BUFFERS),
+        )
+        if GatherIndx is None:
+            gl.amd.gfx1250.tdm.async_load(
+                x_scales_desc,
+                [offs_x_m.to(index_type), load_idx * MX_SCALE_BLOCK_K],
+                x_scales_buffer.index(load_idx % NUM_BUFFERS),
+            )
+        else:
+            gl.amd.gfx1250.tdm.async_gather(
+                x_scales_desc,
+                offs_x_m.to(index_type),
+                load_idx * MX_SCALE_BLOCK_K,
+                x_scales_buffer.index(load_idx % NUM_BUFFERS),
+            )
+        gl.amd.gfx1250.tdm.async_load(
+            w_scales_desc,
+            [offs_w_n_scale.to(index_type), load_idx * PACKED_MX_BLOCK],
+            w_scales_buffer.index(load_idx % NUM_BUFFERS),
+        )
+        load_idx += 1
+
+        # wait for next tile to be filled
+        gl.amd.gfx1250.tdm.async_wait((NUM_BUFFERS - 2) * NUM_LOADS_IN_BATCH)
+
+        # load next tile from LDS into registers
+        next_x = x_buffer.index(wmma_idx % NUM_BUFFERS).load(layout=DOT_LAYOUT_X)
+        next_w = (
+            w_buffer.index(wmma_idx % NUM_BUFFERS)
+            .permute((1, 0))
+            .load(layout=DOT_LAYOUT_W)
+        )
+        next_x_scales = x_scales_buffer.index(wmma_idx % NUM_BUFFERS).load(
+            layout=DOT_LAYOUT_X_SCALES
+        )
+        if SWIZZLE_MX_SCALE == "GFX1250_SCALE":
+            next_w_scales = (
+                w_scales_buffer.index(wmma_idx % NUM_BUFFERS)
+                .reshape(
+                    (
+                        SCALE_BLOCK_N,
+                        MX_SCALE_BLOCK_K // SCALE_KWIDTH,
+                        PRESHUFFLE_FACTOR // 4,
+                        4,
+                        SCALE_KWIDTH,
+                    )
+                )
+                .permute((0, 3, 2, 1, 4))
+                .reshape((BLOCK_N, MX_SCALE_BLOCK_K))
+            ).load(layout=DOT_LAYOUT_W_SCALES)
+        else:
+            next_w_scales = w_scales_buffer.index(wmma_idx % NUM_BUFFERS).load(
+                layout=DOT_LAYOUT_W_SCALES
+            )
+        wmma_idx += 1
+
+        # prepare next iteration
+        cur_x = next_x
+        cur_w = next_w
+        cur_x_scales = next_x_scales
+        cur_w_scales = next_w_scales
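+    # Pipeline accounting: the prologue issued NUM_BUFFERS - 1 tile loads, the
+    # loop above ran num_k_iter - (NUM_BUFFERS - 1) times (one load + one wmma
+    # each), the epilogue below drains NUM_BUFFERS - 2 buffered tiles, and one
+    # final wmma consumes the last tile: num_k_iter wmma_scaled calls in total.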
+
+    # epilogue: drain remaining tiles
+    for k_ep in gl.static_range(NUM_BUFFERS - 2):
+        # issue wmma
+        acc = gl.amd.gfx1250.wmma_scaled(
+            cur_x, cur_x_scales, "e2m1", cur_w, cur_w_scales, "e2m1", acc
+        )
+
+        # wait for next tile to be filled
+        gl.amd.gfx1250.tdm.async_wait((NUM_BUFFERS - 3 - k_ep) * NUM_LOADS_IN_BATCH)
+
+        # load next tile from LDS into registers
+        next_x = x_buffer.index(wmma_idx % NUM_BUFFERS).load(layout=DOT_LAYOUT_X)
+        next_w = (
+            w_buffer.index(wmma_idx % NUM_BUFFERS)
+            .permute((1, 0))
+            .load(layout=DOT_LAYOUT_W)
+        )
+        next_x_scales = x_scales_buffer.index(wmma_idx % NUM_BUFFERS).load(
+            layout=DOT_LAYOUT_X_SCALES
+        )
+        if SWIZZLE_MX_SCALE == "GFX1250_SCALE":
+            # same (..., PRESHUFFLE_FACTOR // 4, 4, SCALE_KWIDTH) unpacking as
+            # in the prologue and main loop
+            next_w_scales = (
+                w_scales_buffer.index(wmma_idx % NUM_BUFFERS)
+                .reshape(
+                    (
+                        SCALE_BLOCK_N,
+                        MX_SCALE_BLOCK_K // SCALE_KWIDTH,
+                        PRESHUFFLE_FACTOR // 4,
+                        4,
+                        SCALE_KWIDTH,
+                    )
+                )
+                .permute((0, 3, 2, 1, 4))
+                .reshape((BLOCK_N, MX_SCALE_BLOCK_K))
+            ).load(layout=DOT_LAYOUT_W_SCALES)
+        else:
+            next_w_scales = w_scales_buffer.index(wmma_idx % NUM_BUFFERS).load(
+                layout=DOT_LAYOUT_W_SCALES
+            )
+        wmma_idx += 1
+
+        # prepare next iteration
+        cur_x = next_x
+        cur_w = next_w
+        cur_x_scales = next_x_scales
+        cur_w_scales = next_w_scales
+
+    # issue last wmma
+    acc = gl.amd.gfx1250.wmma_scaled(
+        cur_x, cur_x_scales, "e2m1", cur_w, cur_w_scales, "e2m1", acc
+    )
+
+    # bias
+    offs_m = BLOCK_M * block_id + gl.arange(
+        0, BLOCK_M, layout=gl.SliceLayout(1, WMMA_LAYOUT)
+    )
+    offs_y_n = BLOCK_N * pid_n + gl.arange(
+        0, BLOCK_N, layout=gl.SliceLayout(0, WMMA_LAYOUT)
+    )
+    mask_m = offs_m < M
+    mask_n = offs_y_n < N
+    if B is not None:
+        BPtrs = B + expt_id * stride_b_e
+        bias = gl.amd.gfx1250.buffer_load(BPtrs, offs_y_n, mask=mask_n)
+        acc = acc + bias[None, :]
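+    # When SwiGLU is fused, the activation reduces every ACTIVATION_REDUCTION_N
+    # output columns to one, so the store below covers OUT_BLOCK_N =
+    # BLOCK_N // ACTIVATION_REDUCTION_N columns out of yN = N // ACTIVATION_REDUCTION_N.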
+
+    # apply activation function
+    if APPLY_SWIGLU and SPLIT_K == 1:
+        if SWIZZLE_MX_SCALE == "GFX1250_SCALE":
+            SWIGLU_LAYOUT: gl.constexpr = gl.BlockedLayout(
+                size_per_thread=[1, 2],
+                threads_per_warp=[16, 2],
+                warps_per_cta=[NUM_WARPS, 1],
+                order=[1, 0],
+            )
+            acc = gl.convert_layout(acc, SWIGLU_LAYOUT)
+        out = _swiglu(acc, alpha, limit, ADD_RESIDUAL)
+        out = gl.convert_layout(out, WMMA_LAYOUT)
+        gl.static_assert(
+            out.shape[1] == OUT_BLOCK_N,
+            f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})",
+        )
+        offs_y_n = OUT_BLOCK_N * pid_n + gl.arange(
+            0, OUT_BLOCK_N, layout=gl.SliceLayout(0, WMMA_LAYOUT)
+        )
+        mask_n = offs_y_n < yN
+    else:
+        gl.static_assert(
+            ACTIVATION_REDUCTION_N == 1,
+            "Activation reduction must be 1 if no activation fn is provided",
+        )
+        out = acc
+
+    # apply gammas
+    if Gammas is not None:
+        gammas = gl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0)
+        out *= gammas[:, None]
+
+    # write-back
+    Y += start_m * stride_y_m
+    offs_y_m = offs_m
+    offs_y = (
+        offs_y_m.to(index_type)[:, None] * stride_y_m
+        + offs_y_n.to(index_type)[None, :] * stride_y_n
+    )
+    mask = mask_m[:, None] & mask_n[None, :]
+    out = out.to(gl.bfloat16)
+    gl.amd.gfx1250.buffer_store(out, Y, offs_y, mask=mask)
diff --git a/aiter/ops/triton/moe/moe_op_gemm_a4w4.py b/aiter/ops/triton/moe/moe_op_gemm_a4w4.py
index 2d3abb87f9..80420742e6 100644
--- a/aiter/ops/triton/moe/moe_op_gemm_a4w4.py
+++ b/aiter/ops/triton/moe/moe_op_gemm_a4w4.py
@@ -4,13 +4,27 @@
 import itertools
 import torch
 import triton
+from triton.experimental import gluon
+import triton.experimental.gluon.language as gl
+
 from aiter.ops.triton.moe.moe_routing.routing import RoutingData
 from aiter.ops.triton._triton_kernels.moe.moe_op_gemm_a4w4 import (
     _mxfp4_quant_kernel,
     _moe_gemm_a4w4,
 )
-from aiter.ops.triton.moe.reduce import reduce_grouped
+from aiter.ops.triton._gluon_kernels.moe.moe_op_gemm_a4w4 import (
+    _moe_gemm_a4w4_gfx1250,
+)
 from aiter.ops.triton.utils._triton.arch_info import get_arch
+from aiter.ops.triton.moe.reduce import reduce_grouped
+
+GLUON_SUPPORTED_ARCHS = {"gfx1250"}
+
+
+def is_gluon_supported():
+    arch = get_arch()
+    return arch in GLUON_SUPPORTED_ARCHS
+
 
 # -----------------------------------------------------------------------------
 # Matrix Multiplication + Outer Gather/Scatter
@@ -65,10 +79,10 @@ def allocate_output(
     return matmul_output, final_output
 
 
-def get_kernel_config(m, n, k, routing_data):
+def get_kernel_config_triton(m, n, k, routing_data):
     block_m = routing_data.block_m
     group_m = 4
-    num_xcds = 8
+    num_xcds = 1
     xcd_swizzle = num_xcds
     w_cache_modifier = ".cg" if block_m <= 32 else None
     arch = get_arch()
@@ -79,7 +93,6 @@ def get_kernel_config(m, n, k, routing_data):
     block_n = 128
     block_k = 256
     num_warps = 4
-
     grid_m = routing_data.n_blocks(m, block_m)
     grid_n = triton.cdiv(n, block_n)
     grid = grid_m * grid_n * split_k
@@ -111,7 +124,29 @@ def get_kernel_config(m, n, k, routing_data):
     return ret
 
 
-def swizzle_scales(data):
+def get_kernel_config_gluon(m, n, k, routing_data):
+    block_m = routing_data.block_m
+    num_xcds = 1
+    num_buffers = 3
+
+    block_n = 128
+    block_k = 512
+    num_warps = 4
+
+    ret = {
+        "block_m": block_m,
+        "block_n": block_n,
+        "block_k": block_k,
+        "num_warps": num_warps,
+        "xcd_swizzle": num_xcds,
+        "num_buffers": num_buffers,
+        "split_k": 1,
+        "waves_per_eu": 0,
+    }
+    return ret
+
+
+def swizzle_scales_gfx950(data):
     NON_K_PRESHUFFLE_BLOCK_SIZE = 32
     block_shape = data.shape
     SCALE_K = block_shape[-2]
@@ -124,6 +159,47 @@ def swizzle_scales(data):
     return data.transpose(-1, -2)
 
 
+def swizzle_scales_gfx1250(data):
+    NON_K_PRESHUFFLE_BLOCK_SIZE = 128
+    block_shape = data.shape
+    SCALE_K = block_shape[-2]
+    N = block_shape[-1]
+    num_chunk_m = N // NON_K_PRESHUFFLE_BLOCK_SIZE
+    SCALE_KWIDTH = 4 if SCALE_K >= 4 else SCALE_K
+    num_chunk_k = SCALE_K // SCALE_KWIDTH
+    data = data.transpose(-1, -2)
+    data = data.view(
+        -1, num_chunk_m, 4, NON_K_PRESHUFFLE_BLOCK_SIZE // 4, num_chunk_k, SCALE_KWIDTH
+    )
+    data = data.permute(0, 1, 4, 3, 2, 5).contiguous()
+    E = block_shape[0]
+    data = data.view(
+        E, N // NON_K_PRESHUFFLE_BLOCK_SIZE, SCALE_K * NON_K_PRESHUFFLE_BLOCK_SIZE
+    )
+    return data.transpose(-1, -2)
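+
+
+# Illustrative shape trace for swizzle_scales_gfx1250, assuming E=1,
+# SCALE_K=16, N=256 (so SCALE_KWIDTH=4): (1, 16, 256) -> transpose ->
+# (1, 256, 16) -> view (1, 2, 4, 32, 4, 4) -> permute + view -> (1, 2, 2048)
+# -> transpose -> (1, 2048, 2): the scales of each 128-row chunk of N become
+# contiguous along K, matching the kernel's PRESHUFFLE_FACTOR = 128 loads.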
+
+
+@gluon.constexpr_function
+def get_wmma_layout(num_warps, packed, scale_preshuffle):
+    assert num_warps in (4, 8)
+    if scale_preshuffle:
+        reg_bases = [[0, 1], [1, 0]]
+        tiles_per_warp = 2
+    else:
+        reg_bases = []
+        tiles_per_warp = 1
+
+    # warp arrangement: [NUM_WARPS // 2, 2]
+    if num_warps == 4:
+        warp_bases = [[0, tiles_per_warp], [tiles_per_warp, 0]]
+    else:
+        warp_bases = [[0, tiles_per_warp], [0, tiles_per_warp * 2], [tiles_per_warp, 0]]
+
+    instr_shape = [16, 16, 64] if packed else [16, 16, 128]
+
+    return gl.amd.AMDWMMALayout(3, True, warp_bases, reg_bases, instr_shape)
+
+
 # -----------------------------------------------------------------------------
 # Triton Implementation
 # -----------------------------------------------------------------------------
@@ -204,12 +280,25 @@ def moe_gemm_a4w4(
     add_residual=True,
     unpadded_N=None,
     unpadded_K=None,
+    config=None,
+    backend=None,  # "triton" | "gluon"
 ):
     """
    Y[:, :] = 0.
    for e in num_experts:
        Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :])
    """
+    backend = (
+        backend
+        if backend is not None
+        else ("gluon" if is_gluon_supported() else "triton")
+    )
+    assert backend in ["triton", "gluon"], f"Invalid backend: {backend}"
+    if backend == "gluon":
+        assert (
+            is_gluon_supported()
+        ), f"Gluon backend is not supported on this architecture: {get_arch()}"
+    use_gluon = backend == "gluon"
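+    # Typical usage (illustrative): with backend=None the Gluon kernel is
+    # picked automatically on gfx1250; explicitly passing backend="gluon" on
+    # an unsupported arch trips the assert above rather than silently falling
+    # back to Triton.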
     assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp"
     x_has_mx = x_scales is not None
     if x_has_mx:
@@ -221,6 +310,7 @@ def moe_gemm_a4w4(
         stride_x_mx_m = 0
         stride_x_mx_k = 0
     # determine shapes
+    num_tokens = x.shape[-2]
     M = x.shape[-2] if gather_indx is None else gather_indx.shape[0]
     K, N = x.shape[-1] * 2, w.shape[-1]
     block_m = routing_data.block_m
@@ -229,7 +319,11 @@ def moe_gemm_a4w4(
     if unpadded_K and block_m == 16:
         K = unpadded_K
     # compute optimization flags
-    config = get_kernel_config(M, N, K, routing_data)
+    if config is None:
+        if use_gluon:
+            config = get_kernel_config_gluon(M, N, K, routing_data)
+        else:
+            config = get_kernel_config_triton(M, N, K, routing_data)
     if apply_swiglu and config["split_k"] > 1:
         apply_swiglu_matmul = False
         reduction_n_matmul = 1
@@ -270,62 +364,126 @@ def moe_gemm_a4w4(
     grid_n = triton.cdiv(N, config["block_n"])
     grid = grid_m * grid_n * config["split_k"]
     # launch kernel
-    _moe_gemm_a4w4[(grid,)](
-        y,
-        y.stride(0),
-        y.stride(1),
-        y.stride(2),
-        x,
-        x.stride(0),
-        x.stride(1),
-        x_scales,
-        stride_x_mx_m,
-        stride_x_mx_k,
-        w,
-        w.stride(0),
-        w.stride(1),
-        w.stride(2),
-        w_scales,
-        w_scales.stride(0),
-        w_scales.stride(1),
-        w_scales.stride(2),
-        x_static_scale,
-        quant_static_scale,
-        bias,
-        stride_bias,
-        gammas,
-        N,
-        K,
-        gather_indx,
-        expt_hist,
-        expt_token_offs_raw,
-        expt_hist_sum,
-        expt_block_pid_map,
-        grid_m,
-        grid_n,
-        apply_swiglu_matmul,
-        alpha,
-        limit,
-        reduction_n_matmul,
-        add_residual,
-        routing_data.n_expts_act,
-        config["block_m"],
-        config["block_n"],
-        config["block_k"],
-        config["group_m"],
-        XCD_SWIZZLE=config["xcd_swizzle"],
-        SWIZZLE_MX_SCALE=swizzle_mx_scale,
-        SPLIT_K=config["split_k"],
-        EVEN_K=K % config["block_k"] == 0,
-        MASK_K_LIMIT=K % config["block_k"],
-        W_CACHE_MODIFIER=config["w_cache_modifier"],
-        num_warps=config["num_warps"],
-        num_stages=config["num_stages"],
-        UPCAST_INDICES=should_upcast_indices(x, w, y),
-        waves_per_eu=config["waves_per_eu"],
-        matrix_instr_nonkdim=config["matrix_instr_nonkdim"],
-        kpack=config["kpack"],
-    )
+    if use_gluon and get_arch() == "gfx1250":
+        # layouts
+        WMMA_LAYOUT = get_wmma_layout(
+            config["num_warps"], False, swizzle_mx_scale == "GFX1250_SCALE"
+        )
+        WMMA_LAYOUT_PACKED = get_wmma_layout(
+            config["num_warps"], True, swizzle_mx_scale == "GFX1250_SCALE"
+        )
+        assert (
+            config["split_k"] == 1
+        ), "Split-k is not supported for Gluon backend on gfx1250"
+        _moe_gemm_a4w4_gfx1250[(grid,)](
+            y,
+            y.stride(0),
+            y.stride(1),
+            y.stride(2),
+            x,
+            x.stride(0),
+            x.stride(1),
+            x_scales,
+            stride_x_mx_m,
+            stride_x_mx_k,
+            w,
+            w.stride(0),
+            w.stride(1),
+            w.stride(2),
+            w_scales,
+            w_scales.stride(0),
+            w_scales.stride(1),
+            w_scales.stride(2),
+            bias,
+            stride_bias,
+            gammas,
+            num_tokens,
+            N,
+            K,
+            gather_indx,
+            expt_hist,
+            expt_token_offs_raw,
+            expt_hist_sum,
+            expt_block_pid_map,
+            grid_m,
+            grid_n,
+            apply_swiglu_matmul,
+            alpha,
+            limit,
+            reduction_n_matmul,
+            add_residual,
+            routing_data.n_expts_act,
+            config["block_m"],
+            config["block_n"],
+            config["block_k"],
+            SPLIT_K=config["split_k"],
+            XCD_SWIZZLE=config["xcd_swizzle"],
+            SWIZZLE_MX_SCALE=swizzle_mx_scale,
+            NUM_BUFFERS=config["num_buffers"],
+            UPCAST_INDICES=should_upcast_indices(x, w, y),
+            WMMA_LAYOUT=WMMA_LAYOUT,
+            WMMA_LAYOUT_PACKED=WMMA_LAYOUT_PACKED,
+            NUM_WARPS=config["num_warps"],
+            num_warps=config["num_warps"],
+            waves_per_eu=config["waves_per_eu"],
+        )
+    else:
+        _moe_gemm_a4w4[(grid,)](
+            y,
+            y.stride(0),
+            y.stride(1),
+            y.stride(2),
+            x,
+            x.stride(0),
+            x.stride(1),
+            x_scales,
+            stride_x_mx_m,
+            stride_x_mx_k,
+            w,
+            w.stride(0),
+            w.stride(1),
+            w.stride(2),
+            w_scales,
+            w_scales.stride(0),
+            w_scales.stride(1),
+            w_scales.stride(2),
+            x_static_scale,
+            quant_static_scale,
+            bias,
+            stride_bias,
+            gammas,
+            N,
+            K,
+            gather_indx,
+            expt_hist,
+            expt_token_offs_raw,
+            expt_hist_sum,
+            expt_block_pid_map,
+            grid_m,
+            grid_n,
+            apply_swiglu_matmul,
+            alpha,
+            limit,
+            reduction_n_matmul,
+            add_residual,
+            routing_data.n_expts_act,
+            config["block_m"],
+            config["block_n"],
+            config["block_k"],
+            config["group_m"],
+            XCD_SWIZZLE=config["xcd_swizzle"],
+            SWIZZLE_MX_SCALE=swizzle_mx_scale,
+            SPLIT_K=config["split_k"],
+            EVEN_K=K % config["block_k"] == 0,
+            MASK_K_LIMIT=K % config["block_k"],
+            W_CACHE_MODIFIER=config["w_cache_modifier"],
+            num_warps=config["num_warps"],
+            num_stages=config["num_stages"],
+            UPCAST_INDICES=should_upcast_indices(x, w, y),
+            waves_per_eu=config["waves_per_eu"],
+            matrix_instr_nonkdim=config["matrix_instr_nonkdim"],
+            kpack=config["kpack"],
+        )
     # Build grouped reduction inputs in a uniform way
     group_indx = (
         None
diff --git a/aiter/ops/triton/moe/moe_routing/routing.py b/aiter/ops/triton/moe/moe_routing/routing.py
index b59221a479..204558ce0d 100644
--- a/aiter/ops/triton/moe/moe_routing/routing.py
+++ b/aiter/ops/triton/moe/moe_routing/routing.py
@@ -1,4 +1,5 @@
 import torch
+from aiter.ops.triton.utils._triton.arch_info import get_arch
 import triton
 from dataclasses import dataclass, field
 from aiter.ops.triton._triton_kernels.moe.moe_routing.routing import (
@@ -256,7 +257,14 @@ def routing(logits, n_expts_act, sm_first=False):
     num_tokens, n_expts_tot = logits.shape
     m = num_tokens * n_expts_act
     tokens_per_expt = max(1, m // n_expts_tot)
-    block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+    if get_arch() == "gfx1250":
+        # gfx1250 favors larger M tiles; keep 32 only for the small m == 32 case
+        if m == 32:
+            block_m = 32
+        else:
+            block_m = 256
+    else:
+        block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
     if num_tokens <= 16:
         HIST_BLOCK_M = triton.next_power_of_2(num_tokens)
     (
diff --git a/op_tests/op_benchmarks/triton/bench_moe_gemm_a4w4.py b/op_tests/op_benchmarks/triton/bench_moe_gemm_a4w4.py
index 064b3f5933..849a306fd6 100644
--- a/op_tests/op_benchmarks/triton/bench_moe_gemm_a4w4.py
+++ b/op_tests/op_benchmarks/triton/bench_moe_gemm_a4w4.py
@@ -11,7 +11,8 @@
 from aiter.ops.triton.moe.moe_op_gemm_a4w4 import (
     mxfp4_quant,
     moe_gemm_a4w4,
-    swizzle_scales,
+    swizzle_scales_gfx950,
+    swizzle_scales_gfx1250,
 )
 from aiter.ops.triton.utils._triton.arch_info import get_arch
 import tempfile
@@ -130,9 +131,12 @@ def inject_proxy_and_call(val, args, kwargs):
 
 
 def check_and_swizzle_scales(scale, N, K):
-    if N % 32 == 0 and K % (32 * 8) == 0:
-        scale = swizzle_scales(scale)
+    if get_arch() == "gfx950" and N % 32 == 0 and K % (32 * 8) == 0:
+        scale = swizzle_scales_gfx950(scale)
         return scale, "CDNA4_SCALE"
+    elif get_arch() == "gfx1250" and N % 128 == 0 and K % (32 * 4) == 0:
+        scale = swizzle_scales_gfx1250(scale)
+        return scale, "GFX1250_SCALE"
     else:
         return scale, None
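+
+
+# Note: the (N, K) guards above mirror what the swizzled layouts assume: 32-row
+# scale chunks with K a multiple of 32 * 8 on gfx950, 128-row chunks with K a
+# multiple of 32 * 4 on gfx1250; any other shape keeps the linear scale layout.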
@@ -191,7 +195,6 @@ def bench_mlp_single_weight_init(
 
     # -- benchmark --
     x_dtype_str = x_dtype
-    x_dtype = torch.float8_e4m3fn
 
     reps = 100
     x = torch.randn((batch, dim1), dtype=torch.bfloat16, device=dev)
@@ -214,7 +217,7 @@ def bench_mlp_single_weight_init(
         b1,
         rdata,
         gather_indx=gather_indx,
-        swizzle_mx_scale="CDNA4_SCALE",
+        swizzle_mx_scale=swizzle_mx_scale1,
         apply_swiglu=True,
     )
     x, x_scale = mxfp4_quant(x)
@@ -228,7 +231,7 @@ def bench_mlp_single_weight_init(
         b2,
         rdata,
         scatter_indx=scatter_indx,
-        swizzle_mx_scale="CDNA4_SCALE",
+        swizzle_mx_scale=swizzle_mx_scale2,
     )
     proton.finalize()
     return parse_profile(
diff --git a/op_tests/triton_tests/moe/test_moe_gemm_a4w4.py b/op_tests/triton_tests/moe/test_moe_gemm_a4w4.py
index 6a10fe1ac1..7805983215 100644
--- a/op_tests/triton_tests/moe/test_moe_gemm_a4w4.py
+++ b/op_tests/triton_tests/moe/test_moe_gemm_a4w4.py
@@ -10,10 +10,12 @@
 
 # matmul utilities
 from aiter.ops.triton.moe.moe_op_gemm_a4w4 import (
+    is_gluon_supported,
     mxfp4_quant,
     moe_gemm_a4w4,
     moe_gemm_torch,
-    swizzle_scales,
+    swizzle_scales_gfx950,
+    swizzle_scales_gfx1250,
 )
 
 # numerics utilities
@@ -68,7 +70,8 @@ def init_compute_data(
     has_y_gammas,
     device="cuda",
 ):
-    torch.manual_seed(0)
+    # TODO: re-enable once torch.manual_seed is supported on this platform
+    # torch.manual_seed(0)
     in_m = m * (n_expts_act if gindx is None else 1)
     shape_x = (in_m, k)
     x = alloc_rand(shape_x, device=device, dtype=act_dtype)
@@ -212,7 +215,7 @@ class Case:
 )
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("apply_swiglu", [False, True])
-@pytest.mark.parametrize("fused_quant", [False, True])
+@pytest.mark.parametrize("backend", ["triton", "gluon"])
 def test_op(
     m,
     n,
@@ -221,23 +224,29 @@ def test_op(
     do_scatter,
     has_y_gammas,
     apply_swiglu,
-    fused_quant,
     n_expts_tot,
     n_expts_act,
     hbm_swizzling,
+    backend,
     device="cuda",
 ):
-    if get_arch() != "gfx950":
-        pytest.skip("FP4 kernels are not supported on MI300.")
     if hbm_swizzling:
-        if n % 32 != 0 or k % (32 * 8) != 0:
+        if get_arch() == "gfx950" and (n % 32 != 0 or k % (32 * 8) != 0):
+            pytest.skip(
+                f"Shape {m}x{n}x{k} is not supported for scale swizzling on AMD GPU"
+            )
+        elif get_arch() == "gfx1250" and (n % 128 != 0 or k % (32 * 4) != 0):
             pytest.skip(
                 f"Shape {m}x{n}x{k} is not supported for scale swizzling on AMD GPU"
             )
-    torch.manual_seed(0)
+    # skip gluon backend if not supported
+    if backend == "gluon" and not is_gluon_supported():
+        pytest.skip(f"Gluon backend is not supported on {get_arch()}")
+
+    # TODO: re-enable once torch.manual_seed is supported on this platform
+    # torch.manual_seed(0)
 
-    act_mxfp4 = "mxfloat4_e2m1"
     weight_mxfp4 = "mxfloat4_e2m1"
     weight_dtype_str = weight_mxfp4[2:]
@@ -264,8 +273,14 @@ def test_op(
     w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=1)
     w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=1)
     if hbm_swizzling:
-        swizzle_mx_scale = "CDNA4_SCALE"
-        w_scale_tri = swizzle_scales(w_scale_tri)
+        if get_arch() == "gfx1250":
+            swizzle_mx_scale = "GFX1250_SCALE"
+            w_scale_tri = swizzle_scales_gfx1250(w_scale_tri)
+        elif get_arch() == "gfx950":
+            swizzle_mx_scale = "CDNA4_SCALE"
+            w_scale_tri = swizzle_scales_gfx950(w_scale_tri)
+        else:
+            assert False, "Unsupported architecture"
     else:
         swizzle_mx_scale = None
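+
+    # swizzle_mx_scale is forwarded to the kernels' SWIZZLE_MX_SCALE constexpr;
+    # "GFX1250_SCALE" additionally selects the preshuffled scale-load path in
+    # the Gluon kernel.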
@@ -279,17 +294,15 @@ def test_op(
     ref_y = moe_gemm_torch(
         x_ref, w_ref, bias_ref, rdata, gindx, sindx, gammas, apply_swiglu
     )
-    if not act_mxfp4 and fused_quant:
-        quant_static_scale = ref_y.abs().max().float() / 448.0
-    else:
-        quant_static_scale = None
+
+    # run kernel
     tri_y = moe_gemm_a4w4(
         x_tri,
         w_tri,
         x_mx_scales_tri,
         w_scale_tri,
        x_static_scale,
-        quant_static_scale,
+        None,
         bias_tri,
         rdata,
         gindx,
@@ -298,7 +311,6 @@ def test_op(
         swizzle_mx_scale,
         out_dtype,
         apply_swiglu,
+        backend=backend,
     )
-    if not act_mxfp4 and fused_quant:
-        tri_y = (tri_y.float() * quant_static_scale).to(ref_y.dtype)
     assert_close(ref_y, tri_y, maxtol=maxtol, rmstol=rmstol)