From 73a6354b592ad2d820cc01d64e0d9a7b2063e768 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Tue, 6 Jan 2026 14:53:17 +0000 Subject: [PATCH 01/10] add kernel and config --- .../gemm/basic/gemm_a16wfp4.py | 187 ++++++++++++++++- ...EMM-A16WFP4_PRESHUFFLED-N=2112-K=7168.json | 86 ++++++++ .../gemm/gfx950-GEMM-A16WFP4_PRESHUFFLED.json | 14 ++ aiter/ops/triton/gemm/basic/gemm_a16wfp4.py | 188 ++++++++++++++++++ .../gemm/basic/test_gemm_a16wfp4.py | 66 +++++- 5 files changed, 527 insertions(+), 14 deletions(-) create mode 100644 aiter/ops/triton/configs/gemm/gfx950-GEMM-A16WFP4_PRESHUFFLED-N=2112-K=7168.json create mode 100644 aiter/ops/triton/configs/gemm/gfx950-GEMM-A16WFP4_PRESHUFFLED.json diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16wfp4.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16wfp4.py index 72ecb7f953..02f902447f 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16wfp4.py @@ -179,8 +179,6 @@ def _gemm_a16wfp4_kernel( + pid_k * stride_ck ) c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - # if pid == 0: - # tl.device_print("c", c) if ATOMIC_ADD: tl.atomic_add(c_ptrs, c, mask=c_mask, sem="relaxed") else: @@ -204,6 +202,186 @@ def _gemm_a16wfp4_kernel( ) +@triton.heuristics( + { + "EVEN_K": lambda args: (args["K"] % (args["BLOCK_SIZE_K"] // 2) == 0) + and (args["SPLITK_BLOCK_SIZE"] % args["BLOCK_SIZE_K"] == 0) + and (args["K"] % (args["SPLITK_BLOCK_SIZE"] // 2) == 0), + "GRID_MN": lambda args: triton.cdiv(args["M"], args["BLOCK_SIZE_M"]) + * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), + } +) +@triton.jit(repr=_gemm_a16wfp4_preshuffle_repr) +def _gemm_a16wfp4_preshuffle_kernel( + a_ptr, + b_ptr, + c_ptr, + b_scales_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bn, + stride_bk, + stride_ck, + stride_cm, + stride_cn, + stride_bsn, + stride_bsk, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + 
BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_KSPLIT: tl.constexpr, + SPLITK_BLOCK_SIZE: tl.constexpr, + EVEN_K: tl.constexpr, + num_warps: tl.constexpr, + num_stages: tl.constexpr, + waves_per_eu: tl.constexpr, + matrix_instr_nonkdim: tl.constexpr, + GRID_MN: tl.constexpr, + PREQUANT: tl.constexpr, + cache_modifier: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A is loaded as unquantized activations and, when PREQUANT is set, quantized to mxfp4 in-kernel; B is preshuffled, packed mxfp4. + B_scales are in e8m0 format and are un-shuffled in-kernel before use. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + tl.assume(stride_bsk > 0) + tl.assume(stride_bsn > 0) + + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid_unified = tl.program_id(axis=0) + pid_k = pid_unified % NUM_KSPLIT + pid = pid_unified // NUM_KSPLIT + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + if NUM_KSPLIT == 1: + pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M) + else: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(pid_k >= 0) + + # We assume 32 elements along K share the same scale. + SCALE_GROUP_SIZE: tl.constexpr = 32 + + if (pid_k * SPLITK_BLOCK_SIZE // 2) < K: + + num_k_iter = tl.cdiv(SPLITK_BLOCK_SIZE // 2, BLOCK_SIZE_K // 2) + + # Create pointers for first block of A and B input matrices + # The BLOCK sizes are of the elements and in fp4 we pack 2 per uint8 container.
+ offs_k_bf16 = tl.arange(0, BLOCK_SIZE_K) + offs_k_split_bf16 = pid_k * SPLITK_BLOCK_SIZE + offs_k_bf16 + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k_split_bf16[None, :] * stride_ak + ) + + # offs_k = tl.arange(0, BLOCK_SIZE_K // 2) + # offs_k_split = pid_k * (SPLITK_BLOCK_SIZE // 2) + offs_k + offs_k_shuffle_arr = tl.arange(0, (BLOCK_SIZE_K // 2) * 16) + offs_k_shuffle = pid_k * (SPLITK_BLOCK_SIZE // 2) * 16 + offs_k_shuffle_arr + offs_bn = (pid_n * (BLOCK_SIZE_N // 16) + tl.arange(0, BLOCK_SIZE_N // 16)) % N + b_ptrs = b_ptr + ( + offs_bn[:, None] * stride_bn + offs_k_shuffle[None, :] * stride_bk + ) + # Create pointers for the first block of A and B scales + offs_bsn = ( + pid_n * (BLOCK_SIZE_N // 32) + tl.arange(0, (BLOCK_SIZE_N // 32)) + ) % N + offs_ks = (pid_k * (SPLITK_BLOCK_SIZE // SCALE_GROUP_SIZE) * 32) + tl.arange( + 0, BLOCK_SIZE_K // SCALE_GROUP_SIZE * 32 + ) + # B scales are N x K even though B operand is K x N. + b_scale_ptrs = ( + b_scales_ptr + + offs_bsn[:, None] * stride_bsn + + offs_ks[None, :] * stride_bsk + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): + b_scales = ( + tl.load(b_scale_ptrs, cache_modifier=cache_modifier) + .reshape( + BLOCK_SIZE_N // 32, + BLOCK_SIZE_K // SCALE_GROUP_SIZE // 8, + 4, + 16, + 2, + 2, + 1, + ) + .permute(0, 5, 3, 1, 4, 2, 6) + .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + + # a_scales = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K//SCALE_GROUP_SIZE), 127, dtype=tl.uint8) + # b_scales = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K//SCALE_GROUP_SIZE), 127, dtype=tl.uint8) + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. 
+ if EVEN_K: + a_bf16 = tl.load(a_ptrs) + b = tl.load(b_ptrs, cache_modifier=cache_modifier) + + b = ( + b.reshape( + 1, + BLOCK_SIZE_N // 16, + BLOCK_SIZE_K // 64, + 2, + 16, + 16, + ) + .permute(0, 1, 4, 2, 3, 5) + .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K // 2) + .trans(1, 0) + ) + + if PREQUANT: + a, a_scales = _mxfp4_quant_op(a_bf16, BLOCK_SIZE_K, BLOCK_SIZE_M, 32) + + accumulator += tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1") + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += (BLOCK_SIZE_K // 2) * 16 * stride_bk + b_scale_ptrs += BLOCK_SIZE_K * stride_bsk + + c = accumulator.to(c_ptr.type.element_ty) + + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + c_ptrs = ( + c_ptr + + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :] + + pid_k * stride_ck + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + def get_splitk(K: int, BLOCK_SIZE_K: int, NUM_KSPLIT: int): # heuristics for make "EVEN_K == True" as much as possible NUM_KSPLIT_STEP = 2 @@ -241,6 +419,9 @@ def _get_config( M: int, N: int, K: int, + shuffle: bool = False, ): + shuffle_suffix = "_PRESHUFFLED" if shuffle else "" + config_name = f"GEMM-A16WFP4{shuffle_suffix}" # Note: Config files use K=2*K in their naming - return get_gemm_config("GEMM-A16WFP4", M, N, 2 * K) + return get_gemm_config(config_name, M, N, 2 * K) diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-A16WFP4_PRESHUFFLED-N=2112-K=7168.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-A16WFP4_PRESHUFFLED-N=2112-K=7168.json new file mode 100644 index 0000000000..ba0d492a1c --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-A16WFP4_PRESHUFFLED-N=2112-K=7168.json @@ -0,0 +1,86 @@ +{ + "M_LEQ_8": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 14 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 14 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 14 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 14 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-A16WFP4_PRESHUFFLED.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-A16WFP4_PRESHUFFLED.json new file mode 100644 index 0000000000..01951d60ce --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-A16WFP4_PRESHUFFLED.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, 
+ "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py b/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py index 345090b982..b8fa8313f9 100644 --- a/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py +++ b/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py @@ -9,6 +9,7 @@ from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_str from aiter.ops.triton._triton_kernels.gemm.basic.gemm_a16wfp4 import ( _gemm_a16wfp4_kernel, + _gemm_a16wfp4_preshuffle_kernel, _get_config, ) from aiter.ops.triton._triton_kernels.gemm.basic.gemm_afp4wfp4 import ( @@ -184,3 +185,190 @@ def gemm_a16wfp4( ) -> torch.Tensor: config_hashable = serialize_dict(config) if config else None return gemm_a16wfp4_(x, w, w_scales, atomic_add, dtype, y, config_hashable) + + +def gemm_a16wfp4_preshuffle_fake_tensor( + x: torch.Tensor, + w: torch.Tensor, + w_scales: torch.Tensor, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[str] = None, + skip_reduce: Optional[bool] = False, +) -> torch.Tensor: + M, K = x.shape + N, _ = w.shape + + config = deserialize_str(config) + + num_ksplit = config["NUM_KSPLIT"] + block_size_k = config["BLOCK_SIZE_K"] + + if num_ksplit > 1: + _, block_size_k, num_ksplit = get_splitk(K, block_size_k, num_ksplit) + + if block_size_k >= 2 * K: + num_ksplit = 1 + + if num_ksplit > 1 and skip_reduce: + y_pp = torch.empty((num_ksplit, M, N), dtype=torch.float32, device=x.device) + return y_pp + + return torch.empty((M, N), dtype=dtype, device=x.device) + + +@torch_compile_guard(gen_fake=gemm_a16wfp4_preshuffle_fake_tensor) +def gemm_a16wfp4_preshuffle_( + x: torch.Tensor, + w: torch.Tensor, + w_scales: torch.Tensor, + prequant: Optional[bool] = True, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[str] = None, + 
skip_reduce: Optional[bool] = False, +) -> torch.Tensor: + """ + Computes matrix multiplication Y = X @ W^T with BF16 activations and preshuffled FP4 weights. + + Key parameters: + x (torch.Tensor): BF16/FP16 input matrix X with shape (M, K). + Quantized to MXFP4 on-the-fly during GEMM (prequant must be True). + w (torch.Tensor): Preshuffled FP4 E2M1 weight matrix W with shape (N//16, (K//2)*16). + w_scales (torch.Tensor): Preshuffled E8M0 per-group scale for w with shape (N//32, K). + One scale per 32 elements in K dimension. + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). + config (Optional[str]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE). + skip_reduce (Optional[bool]): skip reduction, y becomes (SPK, M, N) where SPK is determined by config + + Returns: + y (torch.Tensor): Output with shape (M, N), or fp32 partials of shape (NUM_KSPLIT, M, N) when skip_reduce is set and NUM_KSPLIT > 1. + """ + + _LOGGER.info( + f"GEMM_A16WFP4_PRESHUFFLE: x={tuple(x.shape)} w={tuple(w.shape)} w_scale={tuple(w_scales.shape)} " + ) + + assert arch_info.is_fp4_avail(), "MXFP4 is not available on your device" + assert prequant == True, "prequant == False is not supported yet" + + M, K = x.shape + N, K = w.shape + N = N * 16 + K = K // 16 + + if config is None: + config, _ = _get_config(M, N, K, True) + else: + config = deserialize_str(config) + + if config["NUM_KSPLIT"] > 1: + SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( + K, config["BLOCK_SIZE_K"], config["NUM_KSPLIT"] + ) + + config["SPLITK_BLOCK_SIZE"] = SPLITK_BLOCK_SIZE + config["BLOCK_SIZE_K"] = BLOCK_SIZE_K + config["NUM_KSPLIT"] = NUM_KSPLIT + + if config["BLOCK_SIZE_K"] >= 2 * K: + config["BLOCK_SIZE_K"] = triton.next_power_of_2(2 * K) + config["SPLITK_BLOCK_SIZE"] = 2 * K + config["NUM_KSPLIT"] = 1 + config["BLOCK_SIZE_N"] = max(config["BLOCK_SIZE_N"], 32) + + return_y_pp = config["NUM_KSPLIT"] > 1 and skip_reduce + + if config["NUM_KSPLIT"] > 1: + y_pp = torch.empty( + (config["NUM_KSPLIT"], M, N),
dtype=torch.float32, device=x.device + ) + else: + config["SPLITK_BLOCK_SIZE"] = 2 * K + y_pp = None + + if y is None and not return_y_pp: + y = torch.empty((M, N), dtype=dtype, device=x.device) + + grid = lambda META: ( # noqa: E731 + ( + META["NUM_KSPLIT"] + * triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]) + ), + ) + _gemm_a16wfp4_preshuffle_kernel[grid]( + x, + w, + y if y_pp is None else y_pp, + w_scales, + M, + N, + K, + x.stride(0), + x.stride(1), + w.stride(0), + w.stride(1), + 0 if y_pp is None else y_pp.stride(0), + y.stride(0) if y_pp is None else y_pp.stride(1), + y.stride(1) if y_pp is None else y_pp.stride(2), + w_scales.stride(0), + w_scales.stride(1), + PREQUANT=prequant, + **config, + ) + + if return_y_pp: + return y_pp + elif config["NUM_KSPLIT"] > 1: + REDUCE_BLOCK_SIZE_M = 16 + REDUCE_BLOCK_SIZE_N = 64 + # TODO: Need to debug - REDUCE_BLOCK_SIZE_N=128 with fp32 partials fails + # NOTE: REDUCE_BLOCK_SIZE_N=16 gives best perf with fp32 partials and + # REDUCE_BLOCK_SIZE_N=128 gives best perf with bf16 partials + ACTUAL_KSPLIT = triton.cdiv(K, (config["SPLITK_BLOCK_SIZE"] // 2)) + + grid_reduce = ( + triton.cdiv(M, REDUCE_BLOCK_SIZE_M), + triton.cdiv(N, REDUCE_BLOCK_SIZE_N), + ) + _gemm_afp4wfp4_reduce_kernel[grid_reduce]( + y_pp, + y, + M, + N, + y_pp.stride(0), + y_pp.stride(1), + y_pp.stride(2), + y.stride(0), + y.stride(1), + REDUCE_BLOCK_SIZE_M, + REDUCE_BLOCK_SIZE_N, + ACTUAL_KSPLIT, + triton.next_power_of_2(config["NUM_KSPLIT"]), + ) + + return y + + +def gemm_a16wfp4_preshuffle( + x: torch.Tensor, + w: torch.Tensor, + w_scales: torch.Tensor, + prequant: Optional[bool] = True, + dtype: Optional[torch.dtype] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, + skip_reduce: Optional[bool] = False, +) -> torch.Tensor: + if config is None: + config_hashable = None + M, _ = x.shape + N, K = w.shape + N = N * 16 + K = K // 16 + config, _ = _get_config(M, N, K, True) + 
config_hashable = serialize_dict(config) + return gemm_a16wfp4_preshuffle_(x, w, w_scales, prequant, dtype, y, config_hashable) diff --git a/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py b/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py index acacce6a15..3c021cec92 100644 --- a/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py +++ b/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py @@ -1,7 +1,12 @@ import torch import pytest -from aiter.ops.triton.gemm.basic.gemm_a16wfp4 import gemm_a16wfp4 +from aiter.ops.triton.gemm.basic.gemm_a16wfp4 import gemm_a16wfp4, gemm_a16wfp4_preshuffle import aiter.ops.triton.utils._triton.arch_info as arch_info +from op_tests.triton_tests.gemm.basic.test_gemm_afp4wfp4 import ( + shuffle_scales, + un_shuffle_scales, +) +from aiter.ops.shuffle import shuffle_weight # Note this is specified by the HW and cannot be changed. SCALE_GROUP_SIZE = 32 @@ -15,6 +20,7 @@ def generate_gemm_a16wfp4_inputs( atomic_add: bool, dtype: bool, layout: str = "TN", + shuffle: bool = False, ): torch.manual_seed(5) # 34 is two packed e2m1 values 0010 which is 1.0. 
@@ -49,12 +55,27 @@ def generate_gemm_a16wfp4_inputs( ) w_scales = w_scales.T + if shuffle: + use_int4 = False + weight_shuffle_layout = (16, 16) + w_shuffed = shuffle_weight( + w, layout=weight_shuffle_layout, use_int4=use_int4 + ).reshape( + w.shape[0] // weight_shuffle_layout[0], + w.shape[1] * weight_shuffle_layout[0], + ) + + w_scales_shuffled = shuffle_scales(w_scales) + else: + w_shuffed = w + w_scales_shuffled = w_scales + y = None if output: dtype = torch.float32 if atomic_add else dtype y = torch.zeros((M, N), device=x.device, dtype=dtype) - return x, w, x_scales, w_scales, y + return x, w, w_shuffed, x_scales, w_scales, w_scales_shuffled, y def get_x_vals(): @@ -94,6 +115,7 @@ def get_x_vals(): x_vals += [(32, 512, 7168)] x_vals += [(1, 1280, 8192)] x_vals += [(v, 7168, 2048) for v in [1, 4, 8, 32, 64, 128]] + x_vals += [(v, 2112, 7168) for v in [1, 4, 8, 32, 64, 128]] # x_vals += [(1, 1, SCALE_GROUP_SIZE)] # minimal case, TODO: fix return x_vals @@ -147,9 +169,11 @@ def run_torch(x, w, w_scales, dtype): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) @pytest.mark.parametrize("layout", ["TN", "TT", "NN", "NT"]) @pytest.mark.parametrize("output", [True, False]) -@pytest.mark.parametrize("atomic_add", [True, False]) +@pytest.mark.parametrize( + "atomic_add, shuffle", [(True, False), (False, False), (False, True)] +) def test_gemm_a16wfp4( - M: int, N: int, K: int, dtype, layout, output: bool, atomic_add: bool + M: int, N: int, K: int, dtype, layout, output: bool, atomic_add: bool, shuffle: bool ): if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") @@ -160,16 +184,36 @@ def test_gemm_a16wfp4( if M == 4864 and N == 8192 and K == 4160: pytest.skip("Skipping this config. 
due to compilation error.") - x, w, _, w_scales, y = generate_gemm_a16wfp4_inputs( - M, N, K, output=output, atomic_add=atomic_add, dtype=dtype, layout=layout + x, w, w_triton, _, w_scales, w_scales_triton, y = generate_gemm_a16wfp4_inputs( + M, + N, + K, + output=output, + atomic_add=atomic_add, + dtype=dtype, + layout=layout, + shuffle=shuffle, ) y_dtype = torch.float32 if atomic_add else dtype - if output: - y = gemm_a16wfp4(x, w, w_scales, atomic_add=atomic_add, dtype=y_dtype, y=y).to( - dtype - ) + + if shuffle: + if output: + y = gemm_a16wfp4_preshuffle( + x, w_triton, w_scales_triton, prequant=True, dtype=y_dtype, y=y + ) + else: + y = gemm_a16wfp4_preshuffle( + x, w_triton, w_scales_triton, prequant=True, dtype=y_dtype + ) else: - y = gemm_a16wfp4(x, w, w_scales, atomic_add=atomic_add, dtype=y_dtype).to(dtype) + if output: + y = gemm_a16wfp4( + x, w_triton, w_scales_triton, atomic_add=atomic_add, dtype=y_dtype, y=y + ).to(dtype) + else: + y = gemm_a16wfp4( + x, w_triton, w_scales_triton, atomic_add=atomic_add, dtype=y_dtype + ).to(dtype) torch_out = run_torch(x, w, w_scales, dtype).to(dtype) From be2b403777f5e52e68cdf9d1b8d76a1d78acc3eb Mon Sep 17 00:00:00 2001 From: Shao-Chun Lee Date: Tue, 6 Jan 2026 09:02:52 -0600 Subject: [PATCH 02/10] Update aiter/ops/triton/gemm_a16wfp4.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- aiter/ops/triton/gemm/basic/gemm_a16wfp4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py b/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py index b8fa8313f9..06f0929a60 100644 --- a/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py +++ b/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py @@ -252,7 +252,7 @@ def gemm_a16wfp4_preshuffle_( ) assert arch_info.is_fp4_avail(), "MXFP4 is not available on your device" - assert prequant == True, "prequant == False is not supported yet" + assert prequant, "prequant == False is not supported yet" M, K = 
x.shape N, K = w.shape From 11cd8de17cbec49683829fb840b68d18654adab0 Mon Sep 17 00:00:00 2001 From: Shao-Chun Lee Date: Tue, 6 Jan 2026 09:03:16 -0600 Subject: [PATCH 03/10] Update op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py b/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py index 3c021cec92..b63e8f0bc4 100644 --- a/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py +++ b/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py @@ -4,7 +4,6 @@ import aiter.ops.triton.utils._triton.arch_info as arch_info from op_tests.triton_tests.gemm.basic.test_gemm_afp4wfp4 import ( shuffle_scales, - un_shuffle_scales, ) from aiter.ops.shuffle import shuffle_weight From bc44fbeaa4bcf397808721a86741af137fe2117d Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Tue, 6 Jan 2026 15:05:24 +0000 Subject: [PATCH 04/10] format --- op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py b/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py index b63e8f0bc4..0f92b7f83e 100644 --- a/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py +++ b/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py @@ -2,9 +2,7 @@ import pytest from aiter.ops.triton.gemm.basic.gemm_a16wfp4 import gemm_a16wfp4, gemm_a16wfp4_preshuffle import aiter.ops.triton.utils._triton.arch_info as arch_info -from op_tests.triton_tests.gemm.basic.test_gemm_afp4wfp4 import ( - shuffle_scales, -) +from op_tests.triton_tests.gemm.basic.test_gemm_afp4wfp4 import shuffle_scales from aiter.ops.shuffle import shuffle_weight # Note this is specified by the HW and cannot be changed. 
From cb6a8168f10049adbabad0555382de65d9dec839 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Tue, 6 Jan 2026 15:08:55 +0000 Subject: [PATCH 05/10] black format --- aiter/ops/triton/gemm/basic/gemm_a16wfp4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py b/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py index 06f0929a60..78361c308f 100644 --- a/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py +++ b/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py @@ -320,7 +320,7 @@ def gemm_a16wfp4_preshuffle_( **config, ) - if return_y_pp: + if return_y_pp: return y_pp elif config["NUM_KSPLIT"] > 1: REDUCE_BLOCK_SIZE_M = 16 From 771d1267214a89aeb5a5678f7ec4dc5ff52d5bd0 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Tue, 6 Jan 2026 16:22:17 +0000 Subject: [PATCH 06/10] clean --- aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16wfp4.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16wfp4.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16wfp4.py index 02f902447f..0528288add 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16wfp4.py @@ -138,8 +138,6 @@ def _gemm_a16wfp4_kernel( for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): b_scales = tl.load(b_scale_ptrs) - # a_scales = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K//SCALE_GROUP_SIZE), 127, dtype=tl.uint8) - # b_scales = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K//SCALE_GROUP_SIZE), 127, dtype=tl.uint8) # Load the next block of A and B, generate a mask by checking the K dimension. # If it is out of bounds, set it to 0. 
if EVEN_K: @@ -295,8 +293,6 @@ def _gemm_a16wfp4_preshuffle_kernel( offs_am[:, None] * stride_am + offs_k_split_bf16[None, :] * stride_ak ) - # offs_k = tl.arange(0, BLOCK_SIZE_K // 2) - # offs_k_split = pid_k * (SPLITK_BLOCK_SIZE // 2) + offs_k offs_k_shuffle_arr = tl.arange(0, (BLOCK_SIZE_K // 2) * 16) offs_k_shuffle = pid_k * (SPLITK_BLOCK_SIZE // 2) * 16 + offs_k_shuffle_arr offs_bn = (pid_n * (BLOCK_SIZE_N // 16) + tl.arange(0, BLOCK_SIZE_N // 16)) % N @@ -335,8 +331,6 @@ def _gemm_a16wfp4_preshuffle_kernel( .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K // SCALE_GROUP_SIZE) ) - # a_scales = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K//SCALE_GROUP_SIZE), 127, dtype=tl.uint8) - # b_scales = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K//SCALE_GROUP_SIZE), 127, dtype=tl.uint8) # Load the next block of A and B, generate a mask by checking the K dimension. # If it is out of bounds, set it to 0. if EVEN_K: From 06b93e6256a52bcaa65f59d5e66b554d90e9b289 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Tue, 6 Jan 2026 19:19:25 +0000 Subject: [PATCH 07/10] fix --- aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_split_cat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_split_cat.py b/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_split_cat.py index bc4d7fd3d4..92ef460c4b 100644 --- a/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_split_cat.py +++ b/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_split_cat.py @@ -250,7 +250,7 @@ def fused_gemm_afp4wfp4_preshuffle_split_cat( assert N == D * (S1 + S2), "N is not D * (S1 + S2)" if config is None: - config = _get_config(M, N, K, True) + config, _ = _get_config(M, N, K, True) c1 = torch.empty((M, D, S1 + S3), dtype=dtype, device=x.device) c2 = torch.empty((M, D, S2), dtype=dtype, device=x.device) From 4d0e2645e9c629bc7016537da9ee2889ab1f8273 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Tue, 6 Jan 2026 23:01:45 +0000 Subject: [PATCH 08/10] update config --- 
...MM-AFP4WFP4_PRESHUFFLED-N=2112-K=7168.json | 64 +++++++++---------- ...MM-AFP4WFP4_PRESHUFFLED-N=3072-K=1536.json | 42 ++++++------ 2 files changed, 53 insertions(+), 53 deletions(-) diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP4WFP4_PRESHUFFLED-N=2112-K=7168.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP4WFP4_PRESHUFFLED-N=2112-K=7168.json index 52d24f1bd6..2e4006bde2 100644 --- a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP4WFP4_PRESHUFFLED-N=2112-K=7168.json +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP4WFP4_PRESHUFFLED-N=2112-K=7168.json @@ -12,75 +12,75 @@ "NUM_KSPLIT": 7 }, "M_LEQ_16": { - "BLOCK_SIZE_M": 8, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 1024, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 4, "num_stages": 1, "waves_per_eu": 4, "matrix_instr_nonkdim": 16, - "cache_modifier": null, - "NUM_KSPLIT": 7 + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 }, "M_LEQ_32": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 1024, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 1, - "waves_per_eu": 6, + "waves_per_eu": 1, "matrix_instr_nonkdim": 16, - "cache_modifier": null, - "NUM_KSPLIT": 7 + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 }, "M_LEQ_64": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 1024, + "BLOCK_SIZE_K": 512, "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 1, + "num_warps": 4, + "num_stages": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16, - "cache_modifier": null, + "cache_modifier": ".cg", "NUM_KSPLIT": 7 }, "M_LEQ_128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 1024, "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 1, + "num_warps": 2, + "num_stages": 2, "waves_per_eu": 1, "matrix_instr_nonkdim": 16, "cache_modifier": null, - "NUM_KSPLIT": 7 + "NUM_KSPLIT": 1 }, "M_LEQ_256": { "BLOCK_SIZE_M": 64, 
"BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 1024, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 2, - "waves_per_eu": 1, + "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": null, - "NUM_KSPLIT": 7 + "NUM_KSPLIT": 1 }, "any": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 1024, + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 1, - "waves_per_eu": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", - "NUM_KSPLIT": 7 + "cache_modifier": null, + "NUM_KSPLIT": 1 } -} +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP4WFP4_PRESHUFFLED-N=3072-K=1536.json b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP4WFP4_PRESHUFFLED-N=3072-K=1536.json index f7a56a6f5f..edaa1172c8 100644 --- a/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP4WFP4_PRESHUFFLED-N=3072-K=1536.json +++ b/aiter/ops/triton/configs/gemm/gfx950-GEMM-AFP4WFP4_PRESHUFFLED-N=3072-K=1536.json @@ -1,19 +1,19 @@ { "M_LEQ_8": { - "BLOCK_SIZE_M": 4, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 8, - "num_warps": 8, + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 2, - "waves_per_eu": 6, + "waves_per_eu": 1, "matrix_instr_nonkdim": 16, - "cache_modifier": null, + "cache_modifier": ".cg", "NUM_KSPLIT": 1 }, "M_LEQ_16": { "BLOCK_SIZE_M": 8, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 512, "GROUP_SIZE_M": 1, "num_warps": 4, @@ -27,34 +27,34 @@ "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 512, - "GROUP_SIZE_M": 4, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, - "waves_per_eu": 2, + "waves_per_eu": 1, "matrix_instr_nonkdim": 16, - "cache_modifier": null, + "cache_modifier": ".cg", "NUM_KSPLIT": 1 }, "M_LEQ_64": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 32, 
"BLOCK_SIZE_K": 512, - "GROUP_SIZE_M": 4, - "num_warps": 4, + "GROUP_SIZE_M": 1, + "num_warps": 2, "num_stages": 2, - "waves_per_eu": 4, + "waves_per_eu": 1, "matrix_instr_nonkdim": 16, - "cache_modifier": null, + "cache_modifier": ".cg", "NUM_KSPLIT": 1 }, "M_LEQ_128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 512, - "GROUP_SIZE_M": 4, + "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 2, - "waves_per_eu": 8, + "waves_per_eu": 1, "matrix_instr_nonkdim": 16, "cache_modifier": null, "NUM_KSPLIT": 1 @@ -83,4 +83,4 @@ "cache_modifier": null, "NUM_KSPLIT": 1 } -} +} \ No newline at end of file From 87071277bd28a43ad1d08db2198e420cde16e293 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Wed, 7 Jan 2026 00:06:20 +0000 Subject: [PATCH 09/10] fix api --- aiter/ops/triton/gemm/basic/gemm_a16wfp4.py | 4 ++- .../gemm/basic/test_gemm_a16wfp4.py | 35 ++++++++++++++++--- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py b/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py index 78361c308f..fcef64d6d6 100644 --- a/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py +++ b/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py @@ -371,4 +371,6 @@ def gemm_a16wfp4_preshuffle( K = K // 16 config, _ = _get_config(M, N, K, True) config_hashable = serialize_dict(config) - return gemm_a16wfp4_preshuffle_(x, w, w_scales, prequant, dtype, y, config_hashable) + return gemm_a16wfp4_preshuffle_( + x, w, w_scales, prequant, dtype, y, config_hashable, skip_reduce + ) diff --git a/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py b/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py index 0f92b7f83e..8ba9708a73 100644 --- a/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py +++ b/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py @@ -167,10 +167,24 @@ def run_torch(x, w, w_scales, dtype): @pytest.mark.parametrize("layout", ["TN", "TT", "NN", "NT"]) @pytest.mark.parametrize("output", [True, 
False]) @pytest.mark.parametrize( - "atomic_add, shuffle", [(True, False), (False, False), (False, True)] + "atomic_add, shuffle, skip_reduce", + [ + (True, False, False), + (False, False, False), + (False, True, False), + (False, True, True), + ], ) def test_gemm_a16wfp4( - M: int, N: int, K: int, dtype, layout, output: bool, atomic_add: bool, shuffle: bool + M: int, + N: int, + K: int, + dtype, + layout, + output: bool, + atomic_add: bool, + shuffle: bool, + skip_reduce: bool, ): if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") @@ -196,12 +210,25 @@ def test_gemm_a16wfp4( if shuffle: if output: y = gemm_a16wfp4_preshuffle( - x, w_triton, w_scales_triton, prequant=True, dtype=y_dtype, y=y + x, + w_triton, + w_scales_triton, + prequant=True, + dtype=y_dtype, + y=y, + skip_reduce=skip_reduce, ) else: y = gemm_a16wfp4_preshuffle( - x, w_triton, w_scales_triton, prequant=True, dtype=y_dtype + x, + w_triton, + w_scales_triton, + prequant=True, + dtype=y_dtype, + skip_reduce=skip_reduce, ) + if y.dim() == 3: + y = torch.sum(y, dim=0).to(dtype=dtype) else: if output: y = gemm_a16wfp4( From b79dc46c82f58dcf4e0cae6f207a1e0f198db27b Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Wed, 7 Jan 2026 15:20:19 +0000 Subject: [PATCH 10/10] black format --- op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py b/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py index 8ba9708a73..6075164918 100644 --- a/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py +++ b/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py @@ -1,6 +1,9 @@ import torch import pytest -from aiter.ops.triton.gemm.basic.gemm_a16wfp4 import gemm_a16wfp4, gemm_a16wfp4_preshuffle +from aiter.ops.triton.gemm.basic.gemm_a16wfp4 import ( + gemm_a16wfp4, + gemm_a16wfp4_preshuffle, +) import aiter.ops.triton.utils._triton.arch_info as arch_info from 
op_tests.triton_tests.gemm.basic.test_gemm_afp4wfp4 import shuffle_scales from aiter.ops.shuffle import shuffle_weight