diff --git a/aiter/ops/triton/_triton_kernels/fused_fp8_quant.py b/aiter/ops/triton/_triton_kernels/fused_fp8_quant.py index c84534f162..088b1ce415 100644 --- a/aiter/ops/triton/_triton_kernels/fused_fp8_quant.py +++ b/aiter/ops/triton/_triton_kernels/fused_fp8_quant.py @@ -194,3 +194,152 @@ def _fused_flatten_fp8_group_quant_kernel( out_block_scales.to(out_scales_ptr.dtype.element_ty), mask=block_scale_offs < tl.cdiv(N2, QUANT_BLOCK_SIZE), ) + + +@triton.jit +def _fused_reduce_act_mul_fp8_group_quant( + x_ptr, + y_ptr, + y_scale_ptr, + x2_ptr, + y2_ptr, + M, + N1, + N2, + stride_x_spk, + stride_x_m, + stride_x_n, + stride_y_m, + stride_y_n, + stride_y_scale_m, + stride_y_scale_n, + stride_x2_spk, + stride_x2_m, + stride_x2_n, + stride_y2_m, + stride_y2_n, + # Meta-parameters + ACTIVATION: tl.constexpr, + BLOCK_SIZE_M2: tl.constexpr, + BLOCK_SIZE_N1: tl.constexpr, + BLOCK_SIZE_N2: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, + DTYPE_MAX: tl.constexpr, + DTYPE_MIN: tl.constexpr, + X_HAS_SPLITK: tl.constexpr, + X_NUM_KSPLIT: tl.constexpr, + X_NUM_KSPLIT_POW2: tl.constexpr, + X_MASK: tl.constexpr, +): + + tl.assume(stride_x_spk > 0) + tl.assume(stride_x_m > 0) + tl.assume(stride_x_n > 0) + tl.assume(stride_y_m > 0) + tl.assume(stride_y_n > 0) + tl.assume(stride_y_scale_m > 0) + tl.assume(stride_y_scale_n > 0) + tl.assume(stride_x2_spk > 0) + tl.assume(stride_x2_m > 0) + tl.assume(stride_x2_n > 0) + tl.assume(stride_y2_m > 0) + tl.assume(stride_y2_n > 0) + + m_pid = tl.program_id(axis=0) + if X_HAS_SPLITK and m_pid >= M: + pid2 = m_pid - M + num_pid_n2 = tl.cdiv(N2, BLOCK_SIZE_N2) + pid_m2 = pid2 // num_pid_n2 + pid_n2 = pid2 % num_pid_n2 + offs_m2 = (pid_m2 * BLOCK_SIZE_M2 + tl.arange(0, BLOCK_SIZE_M2)) % M + offs_n2 = (pid_n2 * BLOCK_SIZE_N2 + tl.arange(0, BLOCK_SIZE_N2)) % N2 + offs_spk = tl.arange(0, X_NUM_KSPLIT_POW2) + x2_ptrs = ( + x2_ptr + + offs_spk[:, None, None] * stride_x2_spk + + offs_m2[None, :, None] * stride_x2_m + + offs_n2[None, None, :] * 
stride_x2_n + ) + if X_NUM_KSPLIT_POW2 == X_NUM_KSPLIT: + x2 = tl.load(x2_ptrs) + else: + x2 = tl.load( + x2_ptrs, mask=offs_spk[:, None, None] < X_NUM_KSPLIT, other=0.0 + ) + x2 = tl.sum(x2, axis=0) + + x2 = x2.to(y2_ptr.type.element_ty) + + y2_out_ptrs = ( + y2_ptr + (offs_m2[:, None] * stride_y2_m) + (offs_n2[None, :] * stride_y2_n) + ) + + tl.store(y2_out_ptrs, x2) + return + + n_offs = tl.arange(0, BLOCK_SIZE_N1) + NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N1 // QUANT_BLOCK_SIZE + + mask = None + other = None + if X_HAS_SPLITK: + offs_spk = tl.arange(0, X_NUM_KSPLIT_POW2) + x_ptrs = ( + x_ptr + + offs_spk[:, None] * stride_x_spk + + m_pid * stride_x_m + + n_offs[None, :] * stride_x_n + ) + if X_MASK: + mask = (offs_spk[:, None] < X_NUM_KSPLIT) & (n_offs[None, :] < N1) + other = 0.0 + else: + mask = offs_spk[:, None] < X_NUM_KSPLIT + other = 0.0 + else: + x_ptrs = x_ptr + m_pid * stride_x_m + n_offs * stride_x_n + if X_MASK: + mask = n_offs < N1 + other = 0.0 + + x = tl.load( + x_ptrs, + mask=mask, + other=other, + cache_modifier=".cg", + ).to(tl.float32) + x_mul = tl.load( + x_ptrs + N1 * stride_x_n, + mask=mask, + other=other, + cache_modifier=".cg", + ).to(tl.float32) + + if X_HAS_SPLITK: + x = tl.sum(x, axis=0) + x_mul = tl.sum(x_mul, axis=0) + + x = ACTIVATION(x) * x_mul + + y, y_scale = _fp8_quant_op( + x, 1, BLOCK_SIZE_N1, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN + ) + y = tl.ravel(y) + y_scale = tl.ravel(y_scale) + + if X_MASK: + mask = n_offs < N1 + else: + mask = n_offs < N1 + tl.store( + y_ptr + m_pid * stride_y_m + n_offs * stride_y_n, + y.to(y_ptr.dtype.element_ty), + mask=mask, + ) + g_offs = tl.arange(0, NUM_QUANT_BLOCKS) + num_bs_cols = (N1 + QUANT_BLOCK_SIZE - 1) // QUANT_BLOCK_SIZE + tl.store( + y_scale_ptr + m_pid * stride_y_scale_m + g_offs * stride_y_scale_n, + y_scale.to(y_scale_ptr.dtype.element_ty), + mask=g_offs < num_bs_cols, + ) diff --git a/aiter/ops/triton/_triton_kernels/fused_gemm_a8w8_blockscale_a16w16.py 
b/aiter/ops/triton/_triton_kernels/fused_gemm_a8w8_blockscale_a16w16.py new file mode 100644 index 0000000000..6ac5b5ee46 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/fused_gemm_a8w8_blockscale_a16w16.py @@ -0,0 +1,452 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +from typing import Optional +import functools +import json +import os +import torch +import triton +import triton.language as tl +from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd +from ..utils._triton import arch_info +from ..utils.core import AITER_TRITON_CONFIGS_PATH + + +@triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0, + "GRID_MN_FP8": lambda args: triton.cdiv(args["M"], args["BLOCK_SIZE_M"]) + * triton.cdiv(args["N_fp8"], args["BLOCK_SIZE_N"]), + "GRID_MN_BF16": lambda args: triton.cdiv(args["M"], args["BLOCK_SIZE_M"]) + * triton.cdiv(args["N_bf16"], args["BLOCK_SIZE_N"]), + } +) +@triton.jit +def _fused_gemm_a8w8_blockscale_a16w16_kernel( + # Pointers to matrices + a_fp8_ptr, + b_fp8_ptr, + bias_fp8_ptr, + a_fp8_scale_ptr, + b_fp8_scale_ptr, + c_fp8_ptr, + a_bf16_ptr, + b_bf16_ptr, + bias_bf16_ptr, + c_bf16_ptr, + # Matrix dimensions + M, + N_fp8, + N_bf16, + K, + stride_a_fp8_m, + stride_a_fp8_k, + stride_b_fp8_k, + stride_b_fp8_n, + stride_a_fp8_scale_m, + stride_a_fp8_scale_k, + stride_b_fp8_scale_k, + stride_b_fp8_scale_n, + stride_c_fp8_k, + stride_c_fp8_m, + stride_c_fp8_n, + stride_a_bf16_m, + stride_a_bf16_k, + stride_b_bf16_k, + stride_b_bf16_n, + stride_c_bf16_k, + stride_c_bf16_m, + stride_c_bf16_n, + # Meta-parameters + GROUP_K: tl.constexpr, + GROUP_N: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_KSPLIT: tl.constexpr, + SPLITK_BLOCK_SIZE: tl.constexpr, + ADD_BIAS_FP8: tl.constexpr, + ADD_BIAS_BF16: tl.constexpr, + EVEN_K: tl.constexpr, + GRID_MN_FP8: 
tl.constexpr, + GRID_MN_BF16: tl.constexpr, + SKIP_REDUCE: tl.constexpr, + cache_modifier: tl.constexpr, +): + + tl.assume(stride_a_fp8_m > 0) + tl.assume(stride_a_fp8_k > 0) + tl.assume(stride_b_fp8_k > 0) + tl.assume(stride_b_fp8_n > 0) + tl.assume(stride_c_fp8_k > 0) + tl.assume(stride_c_fp8_m > 0) + tl.assume(stride_c_fp8_n > 0) + tl.assume(stride_a_fp8_scale_m > 0) + tl.assume(stride_a_fp8_scale_k > 0) + tl.assume(stride_b_fp8_scale_k > 0) + tl.assume(stride_b_fp8_scale_n > 0) + + tl.assume(stride_a_bf16_m > 0) + tl.assume(stride_a_bf16_k > 0) + tl.assume(stride_b_bf16_k > 0) + tl.assume(stride_b_bf16_n > 0) + tl.assume(stride_c_bf16_m > 0) + tl.assume(stride_c_bf16_n > 0) + + pid_unified = tl.program_id(axis=0) + pid_k = pid_unified % NUM_KSPLIT + pid = pid_unified // NUM_KSPLIT + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n_fp8 = tl.cdiv(N_fp8, BLOCK_SIZE_N) + num_pid_n_bf16 = tl.cdiv(N_bf16, BLOCK_SIZE_N) + num_pid_n = num_pid_n_fp8 + num_pid_n_bf16 + + if NUM_KSPLIT == 1: + GRID_MN: tl.constexpr = GRID_MN_FP8 + GRID_MN_BF16 + remap_xcd(pid, GRID_MN) + + pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M) + else: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(pid_k >= 0) + + if (pid_k * SPLITK_BLOCK_SIZE) < K: + + num_k_iter = tl.cdiv(SPLITK_BLOCK_SIZE, BLOCK_SIZE_K) + acc_dtype = tl.float32 if c_fp8_ptr.type.element_ty != tl.int8 else tl.int32 + + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_k_split = pid_k * SPLITK_BLOCK_SIZE + offs_k + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_ks_step = BLOCK_SIZE_K // GROUP_K + + if pid_n < num_pid_n_fp8: + offs_b_fp8_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N_fp8 + a_fp8_ptrs = a_fp8_ptr + ( + offs_am[:, None] * stride_a_fp8_m + + offs_k_split[None, :] * stride_a_fp8_k + ) + b_fp8_ptrs = b_fp8_ptr + ( + offs_k_split[:, None] * stride_b_fp8_k + + offs_b_fp8_n[None, :] * 
stride_b_fp8_n + ) + + offs_ks = (pid_k * SPLITK_BLOCK_SIZE) // GROUP_K + a_scale_ptrs = ( + a_fp8_scale_ptr + + offs_am * stride_a_fp8_scale_m + + offs_ks * stride_a_fp8_scale_k + ) + offs_bsn = offs_b_fp8_n // GROUP_N + b_scale_ptrs = ( + b_fp8_scale_ptr + + offs_ks * stride_b_fp8_scale_k + + offs_bsn * stride_b_fp8_scale_n + ) + + if ADD_BIAS_FP8: + if NUM_KSPLIT == 1 or (SKIP_REDUCE and pid_k == 0): + accumulator_fp8 = tl.load(bias_fp8_ptr + offs_b_fp8_n).to( + dtype=acc_dtype + ) + accumulator_fp8 = tl.broadcast_to( + accumulator_fp8[None, :], (BLOCK_SIZE_M, BLOCK_SIZE_N) + ) + else: + accumulator_fp8 = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype + ) + else: + accumulator_fp8 = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype + ) + + for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): + if EVEN_K: + a = tl.load(a_fp8_ptrs) + b = tl.load(b_fp8_ptrs, cache_modifier=cache_modifier) + else: + a = tl.load( + a_fp8_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_fp8_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0, + cache_modifier=cache_modifier, + ) + + a_scale = tl.load(a_scale_ptrs) + b_scale = tl.load(b_scale_ptrs) + + accumulator_fp8 += ( + tl.dot(a, b, input_precision="ieee") + * a_scale[:, None] + * b_scale[None, :] + ) + + a_fp8_ptrs += BLOCK_SIZE_K * stride_a_fp8_k + b_fp8_ptrs += BLOCK_SIZE_K * stride_b_fp8_k + a_scale_ptrs += offs_ks_step * stride_a_fp8_scale_k + b_scale_ptrs += offs_ks_step * stride_b_fp8_scale_k + + c_fp8 = accumulator_fp8.to(c_fp8_ptr.type.element_ty) + + offs_cm = pid_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( + tl.int64 + ) + offs_c_fp8_n = pid_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ).to(tl.int64) + c_fp8_ptrs = ( + c_fp8_ptr + + stride_c_fp8_m * offs_cm[:, None] + + stride_c_fp8_n * offs_c_fp8_n[None, :] + + pid_k * stride_c_fp8_k + ) + c_fp8_mask = (offs_cm[:, None] < M) & (offs_c_fp8_n[None, :] < 
N_fp8) + tl.store(c_fp8_ptrs, c_fp8, mask=c_fp8_mask) + else: + pid_n -= num_pid_n_fp8 + + offs_b_bf16_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N_bf16 + a_ptrs = a_bf16_ptr + ( + offs_am[:, None] * stride_a_bf16_m + + offs_k_split[None, :] * stride_a_bf16_k + ) + b_ptrs = b_bf16_ptr + ( + offs_k_split[:, None] * stride_b_bf16_k + + offs_b_bf16_n[None, :] * stride_b_bf16_n + ) + + if ADD_BIAS_BF16: + if NUM_KSPLIT == 1 or (SKIP_REDUCE and pid_k == 0): + accumulator_bf16 = tl.load(bias_bf16_ptr + offs_b_bf16_n).to( + dtype=acc_dtype + ) + accumulator_bf16 = tl.broadcast_to( + accumulator_bf16[None, :], (BLOCK_SIZE_M, BLOCK_SIZE_N) + ) + else: + accumulator_bf16 = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype + ) + else: + accumulator_bf16 = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype + ) + + for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs, cache_modifier=cache_modifier) + else: + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0, + cache_modifier=cache_modifier, + ) + + accumulator_bf16 += tl.dot(a, b, input_precision="ieee") + + a_ptrs += BLOCK_SIZE_K * stride_a_bf16_k + b_ptrs += BLOCK_SIZE_K * stride_b_bf16_k + + c_bf16 = accumulator_bf16.to(c_bf16_ptr.type.element_ty) + + # Write back the block of the output matrix C with masks. 
+ offs_cm = pid_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( + tl.int64 + ) + offs_c_bf16_n = pid_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ).to(tl.int64) + c_bf16_ptrs = ( + c_bf16_ptr + + stride_c_bf16_m * offs_cm[:, None] + + stride_c_bf16_n * offs_c_bf16_n[None, :] + + pid_k * stride_c_bf16_k + ) + c_bf16_mask = (offs_cm[:, None] < M) & (offs_c_bf16_n[None, :] < N_bf16) + tl.store(c_bf16_ptrs, c_bf16, mask=c_bf16_mask) + + +@triton.jit +def _fused_gemm_a8w8_blockscale_a16w16_reduce_kernel( + bias_fp8_ptr, + c_fp8_in_ptr, + c_fp8_out_ptr, + bias_bf16_ptr, + c_bf16_in_ptr, + c_bf16_out_ptr, + M, + N_fp8, + N_bf16, + stride_c_fp8_in_k, + stride_c_fp8_in_m, + stride_c_fp8_in_n, + stride_c_fp8_out_m, + stride_c_fp8_out_n, + stride_c_bf16_in_k, + stride_c_bf16_in_m, + stride_c_bf16_in_n, + stride_c_bf16_out_m, + stride_c_bf16_out_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + ACTUAL_KSPLIT: tl.constexpr, + MAX_KSPLIT: tl.constexpr, + ADD_BIAS_FP8: tl.constexpr, + ADD_BIAS_BF16: tl.constexpr, +): + + tl.assume(stride_c_fp8_in_k > 0) + tl.assume(stride_c_fp8_in_m > 0) + tl.assume(stride_c_fp8_in_n > 0) + tl.assume(stride_c_fp8_out_m > 0) + tl.assume(stride_c_fp8_out_n > 0) + + tl.assume(stride_c_bf16_in_k > 0) + tl.assume(stride_c_bf16_in_m > 0) + tl.assume(stride_c_bf16_in_n > 0) + tl.assume(stride_c_bf16_out_m > 0) + tl.assume(stride_c_bf16_out_n > 0) + + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + + num_pid_n_fp8 = tl.cdiv(N_fp8, BLOCK_SIZE_N) + offs_k = tl.arange(0, MAX_KSPLIT) + acc_dtype = tl.float32 if c_fp8_in_ptr.type.element_ty != tl.int8 else tl.int32 + + if pid_n < num_pid_n_fp8: + offs_fp8_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N_fp8 + c_fp8_in_ptrs = ( + c_fp8_in_ptr + + (offs_k[:, None, None] * stride_c_fp8_in_k) + + (offs_m[None, :, None] * 
stride_c_fp8_in_m) + + (offs_fp8_n[None, None, :] * stride_c_fp8_in_n) + ) + + if ACTUAL_KSPLIT == MAX_KSPLIT: + c_fp8 = tl.load(c_fp8_in_ptrs) + else: + c_fp8 = tl.load( + c_fp8_in_ptrs, mask=offs_k[:, None, None] < ACTUAL_KSPLIT, other=0.0 + ) + c_fp8 = tl.sum(c_fp8, axis=0) + if ADD_BIAS_FP8: + bias_fp8 = tl.load(bias_fp8_ptr + offs_fp8_n).to(dtype=acc_dtype) + bias_fp8 = tl.broadcast_to(bias_fp8[None, :], (BLOCK_SIZE_M, BLOCK_SIZE_N)) + c_fp8 += bias_fp8 + + c_fp8 = c_fp8.to(c_fp8_out_ptr.type.element_ty) + + c_fp8_out_ptrs = ( + c_fp8_out_ptr + + (offs_m[:, None] * stride_c_fp8_out_m) + + (offs_fp8_n[None, :] * stride_c_fp8_out_n) + ) + + tl.store(c_fp8_out_ptrs, c_fp8) + else: + pid_n -= num_pid_n_fp8 + + offs_bf16_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N_bf16 + c_bf16_in_ptrs = ( + c_bf16_in_ptr + + (offs_k[:, None, None] * stride_c_bf16_in_k) + + (offs_m[None, :, None] * stride_c_bf16_in_m) + + (offs_bf16_n[None, None, :] * stride_c_bf16_in_n) + ) + + if ACTUAL_KSPLIT == MAX_KSPLIT: + c_bf16 = tl.load(c_bf16_in_ptrs) + else: + c_bf16 = tl.load( + c_bf16_in_ptrs, mask=offs_k[:, None, None] < ACTUAL_KSPLIT, other=0.0 + ) + c_bf16 = tl.sum(c_bf16, axis=0) + if ADD_BIAS_BF16: + bias_bf16 = tl.load(bias_bf16_ptr + offs_bf16_n).to(dtype=acc_dtype) + bias_bf16 = tl.broadcast_to( + bias_bf16[None, :], (BLOCK_SIZE_M, BLOCK_SIZE_N) + ) + c_bf16 += bias_bf16 + + c_bf16 = c_bf16.to(c_bf16_out_ptr.type.element_ty) + + c_bf16_out_ptrs = ( + c_bf16_out_ptr + + (offs_m[:, None] * stride_c_bf16_out_m) + + (offs_bf16_n[None, :] * stride_c_bf16_out_n) + ) + c_bf16_mask = (offs_m[:, None] < M) & (offs_bf16_n[None, :] < N_bf16) + tl.store(c_bf16_out_ptrs, c_bf16, mask=c_bf16_mask) + + +@functools.lru_cache(maxsize=1024) +def _get_config( + M: int, + N_fp8: int, + N_bf16: int, + K: int, +): + if not hasattr(_get_config, "_config_dict"): + dev = arch_info.get_device() + _get_config._config_dict = {} + fpath = 
f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16.json" + with open(fpath, "r") as file: + config = json.load(file) + _get_config._config_dict["default"] = config + + key = f"{N_fp8}_{N_bf16}_{K}" + if key not in _get_config._config_dict.keys(): + dev = arch_info.get_device() + fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16-N8={N_fp8}-N16={N_bf16}-K={K}.json" + if os.path.exists(fpath): + with open(fpath, "r") as file: + config = json.load(file) + _get_config._config_dict[key] = config + else: + key = "default" # fall back to default config + + if M < 16 and "small" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["small"] + elif M < 32 and "small_M16" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["small_M16"] + elif M <= 128: + BLK_M = triton.next_power_of_2(M) + if BLK_M == 32 and "medium_M32" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["medium_M32"] + elif BLK_M == 64 and "medium_M64" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["medium_M64"] + elif BLK_M == 128 and "medium_M128" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["medium_M128"] + elif M <= 256 and "large" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["large"] + else: + BLK_M = triton.next_power_of_2(M) + if f"xlarge_M{BLK_M}" in _get_config._config_dict[key]: + return _get_config._config_dict[key][f"xlarge_M{BLK_M}"] + elif "xlarge" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["xlarge"] + + return _get_config._config_dict[key]["any"] diff --git a/aiter/ops/triton/configs/gemm/MI350X-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16-N8=512-N16=256-K=7168.json b/aiter/ops/triton/configs/gemm/MI350X-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16-N8=512-N16=256-K=7168.json new file mode 100644 index 0000000000..8658e29031 --- /dev/null +++ 
b/aiter/ops/triton/configs/gemm/MI350X-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16-N8=512-N16=256-K=7168.json @@ -0,0 +1,110 @@ +{ + "small": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "small_M16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "medium_M32": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "medium_M128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "xlarge_M1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "xlarge_M2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + 
"matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16.json b/aiter/ops/triton/configs/gemm/MI350X-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16.json new file mode 100644 index 0000000000..b1a39b0746 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/fused_fp8_quant.py b/aiter/ops/triton/fused_fp8_quant.py index 652c920106..46c2646d74 100644 --- a/aiter/ops/triton/fused_fp8_quant.py +++ b/aiter/ops/triton/fused_fp8_quant.py @@ -1,10 +1,19 @@ +from typing import Optional import torch import triton import aiter from aiter.ops.triton._triton_kernels.fused_fp8_quant import ( _fused_rms_fp8_group_quant_kernel, _fused_flatten_fp8_group_quant_kernel, + _fused_reduce_act_mul_fp8_group_quant, ) +from aiter.ops.triton._triton_kernels.activation import ( + _get_activation_from_str, +) +from aiter.ops.triton.utils.logger import AiterTritonLogger + +_LOGGER = AiterTritonLogger() + fp8_dtype = aiter.dtypes.fp8 @@ -148,6 +157,16 @@ def fused_flatten_fp8_group_quant( group_size, dtype_quant=fp8_dtype, ): + """ + Flatten the last two dimension of x and perform FP8 per-token group quantization along the last dimension + + Key parameters: + - x: Matrix X with shape (M, N1, N2). + + Returns: + - out: The output matrix with shape (M, N1 * N2). 
+ - out_block_scales: The output matrix with shape (M, cdiv((N1 * N2), group_size)). + """ + M, N1, N2 = x.shape + BLOCK_SIZE_N2 = max(triton.next_power_of_2(N2), group_size) @@ -181,3 +200,127 @@ def fused_flatten_fp8_group_quant( ) return out, out_block_scales + + +def fused_reduce_act_mul_fp8_group_quant( + x: torch.Tensor, + activation: str = "silu", + x2: Optional[torch.Tensor] = None, + group_size=128, + dtype_quant=fp8_dtype, + dtype: Optional[float] = torch.bfloat16, +): + """ + Apply reduction along the first dimension and apply the activation function + followed by per-token group quantization. + If x2 is provided, only the reduction along the first dimension is applied to x2 + + Args: + if x is 3-dim, + x: (SPK, M, 2*N1), dtype = fp32. + x2: (SPK, M, N2), dtype = fp32. + + if x is 2-dim, + x: (M, 2*N1), dtype = fp16 or bf16. + x2 must be None + the kernel is essentially identical to aiter.ops.triton.activation.act_mul_and_fp8_group_quant + + activation: activation function to apply before quantization. + - It splits the features into two parts and applies the activation to the first part. + - Then, it multiplies the two parts together before quantization. 
+ - Supports the following activations: + - "silu" + - "gelu" + - "gelu_tanh" + + Returns: + tuple: (y, y_scale), y2 + y: (M, N1), dtype = dtype_quant + y_scale: (M, cdiv(N1, group_size)), dtype = fp32 + y2: (M, N2), dtype = dtype + """ + _LOGGER.info(f"FUSED_REDUCTION_ACT_MUL_FP8_GROUP_QUANT: x={tuple(x.shape)}") + + assert ( + x.dim() == 2 or x.dim() == 3 + ), "The number of dimensions for x should be 2 or 3" + X_HAS_SPLITK = False + x_num_splitk = 1 + N2 = 1 + y2 = None + if x.dim() == 3: + x_num_splitk, M, N1 = x.shape + x_num_splitk, _, N2 = x2.shape + assert ( + x.shape[0] == x2.shape[0] and x.shape[1] == x2.shape[1] + ), "The first two dimensions should be identical between x and x2" + assert ( + x_num_splitk > 1 + ), "x.shape[0] should be larger than 1 in x.dim() == 3 cases" + X_HAS_SPLITK = True + y2 = torch.empty((M, N2), dtype=dtype, device=x2.device) + else: + M, N1 = x.shape + assert x2 is None, "x2 should be None in x.dim() == 2 cases" + + assert ( + N1 % 2 == 0 + ), "The last dimension for x1 should be multiple of 2 for activation and multiplication" + N1 = N1 // 2 + + y = torch.empty((M, N1), dtype=dtype_quant, device=x.device) + y_scale = torch.empty( + (M, (N1 + group_size - 1) // group_size), + dtype=torch.float32, + device=x.device, + ) + + BLOCK_SIZE_N1 = max(triton.next_power_of_2(N1), group_size) + BLOCK_SIZE_N2 = max(triton.next_power_of_2(N2), 32) + BLOCK_SIZE_M2 = 1 if M <= 128 else 4 + X_MASK = N1 % BLOCK_SIZE_N1 != 0 + + DTYPE_MAX = ( + torch.finfo(y.dtype).max + if torch.is_floating_point(y) + else torch.iinfo(y.dtype).max + ) + num_pid = M + if X_HAS_SPLITK: + num_pid += triton.cdiv(M, BLOCK_SIZE_M2) * triton.cdiv(N2, BLOCK_SIZE_N2) + grid = (num_pid,) + _fused_reduce_act_mul_fp8_group_quant[grid]( + x, + y, + y_scale, + x2, + y2, + M, + N1, + N2, + 0 if not X_HAS_SPLITK else x.stride(0), + x.stride(0) if not X_HAS_SPLITK else x.stride(1), + x.stride(1) if not X_HAS_SPLITK else x.stride(2), + y.stride(0), + y.stride(1), + 
y_scale.stride(0), + y_scale.stride(1), + 0 if not X_HAS_SPLITK else x2.stride(0), + 0 if not X_HAS_SPLITK else x2.stride(1), + 0 if not X_HAS_SPLITK else x2.stride(2), + 0 if not X_HAS_SPLITK else y2.stride(0), + 0 if not X_HAS_SPLITK else y2.stride(1), + ACTIVATION=_get_activation_from_str(activation) if activation else "", + BLOCK_SIZE_M2=BLOCK_SIZE_M2, + BLOCK_SIZE_N1=BLOCK_SIZE_N1, + BLOCK_SIZE_N2=BLOCK_SIZE_N2, + QUANT_BLOCK_SIZE=group_size, + DTYPE_MAX=DTYPE_MAX, + DTYPE_MIN=-DTYPE_MAX, + X_HAS_SPLITK=X_HAS_SPLITK, + X_NUM_KSPLIT=x_num_splitk, + X_NUM_KSPLIT_POW2=triton.next_power_of_2(x_num_splitk), + X_MASK=X_MASK, + num_warps=1 if max(BLOCK_SIZE_N1, BLOCK_SIZE_N2) <= 512 else 4, + ) + + return (y, y_scale), y2 diff --git a/aiter/ops/triton/fused_gemm_a8w8_blockscale_a16w16.py b/aiter/ops/triton/fused_gemm_a8w8_blockscale_a16w16.py new file mode 100644 index 0000000000..40f5f62633 --- /dev/null +++ b/aiter/ops/triton/fused_gemm_a8w8_blockscale_a16w16.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +from typing import Optional +import functools +import json +import os +import torch +import triton +import triton.language as tl +from aiter.ops.triton._triton_kernels.fused_gemm_a8w8_blockscale_a16w16 import ( + _fused_gemm_a8w8_blockscale_a16w16_kernel, + _fused_gemm_a8w8_blockscale_a16w16_reduce_kernel, + _get_config, +) +from aiter.ops.triton.utils.logger import AiterTritonLogger + +_LOGGER = AiterTritonLogger() + + +def fused_gemm_a8w8_blockscale_a16w16( + x_fp8: torch.Tensor, + w_fp8: torch.Tensor, + x_fp8_scale: torch.Tensor, + w_fp8_scale: torch.Tensor, + x_bf16: torch.Tensor, + w_bf16: torch.Tensor, + bias_fp8: Optional[torch.Tensor] = None, + bias_bf16: Optional[torch.Tensor] = None, + dtype: Optional[float] = torch.bfloat16, + y_fp8: Optional[torch.Tensor] = None, + y_bf16: Optional[torch.Tensor] = None, + skip_reduce: Optional[bool] = False, + config: Optional[dict] = None, +): + """ + Computes the 8 bit matmul Y = X x WT + B using the block-scale quantization approach for x_fp8 and w_fp8. + Computes the 16 bit matmul Y = X x WT + B for x_bf16 and w_bf16 + + This fusion is primarily aiming for fusing the gate up-projections and MOE gating: + gate up-projections: (M, K) x (2N, K) = (M, 2N) + MOE gating: (M, K) x (N, K) + (N, ) = (M, N) + + Key parameters: + - x_fp8: Matrix X with shape (M, K). + - w_fp8: Matrix W with shape (N_fp8, K). + - x_fp8_scale: Scale tensor for X with shape (M, *scale_k). + - w_fp8_scale: Scale tensor for W with shape (**scale_n, *scale_k). + - x_bf16: Matrix X with shape (M, K). + - w_bf16: Matrix W with shape (N_bf16, K). + + Note: M and K must be identical for x_fp8 and x_bf16, but the N-dim for w_fp8 and w_bf16 can be different + + Returns: + - Y: The output matrix with shape (M, N). 
+ + *scale_k = (K + scale_block_size_k - 1) // scale_block_size_k + **scale_n = (N_fp8 + scale_block_size_n - 1) // scale_block_size_n + """ + _LOGGER.info( + f"FUSED_GEMM_A8W8_BLOCKSCALE_A16W16: x_fp8={tuple(x_fp8.shape)} w_fp8={tuple(w_fp8.shape)} x_fp8_scale={tuple(x_fp8_scale.shape)} w_scale={tuple(w_fp8_scale.shape)} x_bf16={tuple(x_bf16.shape)} w_bf16={tuple(w_bf16.shape)}" + ) + + M, K = x_fp8.shape + N_fp8, K = w_fp8.shape + M, K = x_bf16.shape + N_bf16, K = w_bf16.shape + + # Check constraints. + assert ( + x_fp8.shape[0] == x_bf16.shape[0] + ), "M-dim should be identical for x_fp8 and x_bf16" + assert ( + x_fp8.shape[1] == x_bf16.shape[1] + ), "K-dim should be identical for x_fp8 and x_bf16" + assert x_fp8.shape[1] == w_fp8.shape[1], "Incompatible dimensions!!!" + assert x_bf16.shape[1] == w_bf16.shape[1], "Incompatible dimensions!!!" + + # Transpose w and w_scale + w_fp8 = w_fp8.T + w_bf16 = w_bf16.T + w_fp8_scale = w_fp8_scale.T + + if config is None: + config = _get_config(M, N_fp8, N_bf16, K) + + if y_fp8 is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): + y_fp8 = torch.empty((M, N_fp8), dtype=dtype, device=x_fp8.device) + + if y_bf16 is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): + y_bf16 = torch.empty((M, N_bf16), dtype=dtype, device=x_bf16.device) + + config["SPLITK_BLOCK_SIZE"] = triton.cdiv(K, config["NUM_KSPLIT"]) + if config["NUM_KSPLIT"] > 1: + y_fp8_pp = torch.empty( + (config["NUM_KSPLIT"], M, N_fp8), + dtype=torch.float32, + device=y_fp8.device if y_fp8 is not None else x_fp8.device, + ) + y_bf16_pp = torch.empty( + (config["NUM_KSPLIT"], M, N_bf16), + dtype=torch.float32, + device=y_bf16.device if y_bf16 is not None else x_bf16.device, + ) + else: + y_fp8_pp = None + y_bf16_pp = None + + if config["BLOCK_SIZE_K"] > config["SPLITK_BLOCK_SIZE"]: + config["BLOCK_SIZE_K"] = triton.next_power_of_2(config["SPLITK_BLOCK_SIZE"]) + if config["BLOCK_SIZE_K"] > config["SPLITK_BLOCK_SIZE"]: + config["BLOCK_SIZE_K"] = 
config["BLOCK_SIZE_K"] // 4 + config["BLOCK_SIZE_K"] = max(config["BLOCK_SIZE_K"], 16) + + # Scale block sizes + # TODO: need a better way to pass scale block sizes around + config["GROUP_K"] = triton.next_power_of_2(triton.cdiv(K, w_fp8_scale.shape[0])) + config["GROUP_N"] = triton.next_power_of_2(triton.cdiv(N_fp8, w_fp8_scale.shape[1])) + + # grid = (config["NUM_KSPLIT"], triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(N, config["BLOCK_SIZE_N"]),) + grid = lambda META: ( # noqa: E731 + ( + META["NUM_KSPLIT"] + * triton.cdiv(M, META["BLOCK_SIZE_M"]) + * ( + triton.cdiv(N_fp8, META["BLOCK_SIZE_N"]) + + triton.cdiv(N_bf16, META["BLOCK_SIZE_N"]) + ) + ), + ) + _fused_gemm_a8w8_blockscale_a16w16_kernel[grid]( + x_fp8, + w_fp8, + bias_fp8, + x_fp8_scale, + w_fp8_scale, + y_fp8 if config["NUM_KSPLIT"] == 1 else y_fp8_pp, + x_bf16, + w_bf16, + bias_bf16, + y_bf16 if config["NUM_KSPLIT"] == 1 else y_bf16_pp, + M, + N_fp8, + N_bf16, + K, + x_fp8.stride(0), + x_fp8.stride(1), + w_fp8.stride(0), + w_fp8.stride(1), + x_fp8_scale.stride(0), + x_fp8_scale.stride(1), + w_fp8_scale.stride(0), + w_fp8_scale.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_fp8_pp.stride(0), + y_fp8.stride(0) if config["NUM_KSPLIT"] == 1 else y_fp8_pp.stride(1), + y_fp8.stride(1) if config["NUM_KSPLIT"] == 1 else y_fp8_pp.stride(2), + x_bf16.stride(0), + x_bf16.stride(1), + w_bf16.stride(0), + w_bf16.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_bf16_pp.stride(0), + y_bf16.stride(0) if config["NUM_KSPLIT"] == 1 else y_bf16_pp.stride(1), + y_bf16.stride(1) if config["NUM_KSPLIT"] == 1 else y_bf16_pp.stride(2), + ADD_BIAS_FP8=(bias_fp8 is not None), + ADD_BIAS_BF16=(bias_bf16 is not None), + SKIP_REDUCE=skip_reduce, + **config, + ) + + if config["NUM_KSPLIT"] > 1: + if skip_reduce: + return y_fp8_pp, y_bf16_pp + REDUCE_BLOCK_SIZE_M = 32 + REDUCE_BLOCK_SIZE_N = 32 + ACTUAL_KSPLIT = triton.cdiv(K, config["SPLITK_BLOCK_SIZE"]) + + grid_reduce = ( + triton.cdiv(M, REDUCE_BLOCK_SIZE_M), + 
triton.cdiv(N_fp8, REDUCE_BLOCK_SIZE_N) + + triton.cdiv(N_bf16, REDUCE_BLOCK_SIZE_N), + ) + _fused_gemm_a8w8_blockscale_a16w16_reduce_kernel[grid_reduce]( + bias_fp8, + y_fp8_pp, + y_fp8, + bias_bf16, + y_bf16_pp, + y_bf16, + M, + N_fp8, + N_bf16, + y_fp8_pp.stride(0), + y_fp8_pp.stride(1), + y_fp8_pp.stride(2), + y_fp8.stride(0), + y_fp8.stride(1), + y_bf16_pp.stride(0), + y_bf16_pp.stride(1), + y_bf16_pp.stride(2), + y_bf16.stride(0), + y_bf16.stride(1), + REDUCE_BLOCK_SIZE_M, + REDUCE_BLOCK_SIZE_N, + ACTUAL_KSPLIT, + triton.next_power_of_2(config["NUM_KSPLIT"]), + ADD_BIAS_FP8=(bias_fp8 is not None), + ADD_BIAS_BF16=(bias_bf16 is not None), + ) + + return y_fp8, y_bf16 diff --git a/op_tests/triton_tests/test_fused_fp8_quant.py b/op_tests/triton_tests/test_fused_fp8_quant.py new file mode 100644 index 0000000000..9621756f4f --- /dev/null +++ b/op_tests/triton_tests/test_fused_fp8_quant.py @@ -0,0 +1,219 @@ +import torch +import pytest +from aiter.ops.triton.fused_fp8_quant import ( + fused_rms_fp8_group_quant, + fused_flatten_fp8_group_quant, + fused_reduce_act_mul_fp8_group_quant, +) +from op_tests.triton_tests.test_quant_mxfp4 import torch_dynamic_mxfp4_quant +import aiter +import torch.nn.functional as F + +torch.manual_seed(0) + + +def rmsnorm(input, weight, eps=1e-6): + row_norm = input * input + row_norm = torch.sum(row_norm, dim=-1) + norm_factor = torch.rsqrt((row_norm / input.shape[1]) + eps) + rms_norm = input * norm_factor[:, None] * weight[None, :] + return rms_norm + + +def per_token_fp8_group_quant(x, dtype_quant, group_size=128): + DTYPE_MAX = torch.finfo(dtype_quant).max + M, N = x.shape + x_reshape = x.reshape(M, N // group_size, group_size).to(torch.float32) + x_max = torch.max(torch.abs(x_reshape), dim=-1, keepdim=True)[0] + x_max = torch.where(x_max < 1e-10, 1e-10, x_max).to(torch.float32) + x_scale = x_max / DTYPE_MAX + scale_recip = 1.0 / x_scale + x_quant = torch.clamp(x_reshape * scale_recip, -DTYPE_MAX, DTYPE_MAX).to( + dtype_quant + ) + 
x_quant = x_quant.reshape(M, N) + x_scale = x_scale.squeeze(-1) + + return x_quant, x_scale + + +def upcast(x, s, dtype, group_size=128): + x_N = x.shape[1] + x = x.reshape(-1, x_N // group_size, group_size).to(torch.float32) * s.reshape( + -1, s.shape[1], 1 + ) + x = x.reshape(-1, x_N) + return x.to(dtype=dtype) + + +def run_torch_rms_fp8_group_quant( + x1, w1, eps1, x2, w2, eps2, res1, dtype_quant, group_size +): + s = x1 + res1 + y1 = rmsnorm(s, w1, eps1) + y2 = rmsnorm(x2, w2, eps2) + y1_q, y1_s = per_token_fp8_group_quant(y1, dtype_quant, group_size) + return (y1_q, y1_s), y1.to(x1.dtype), y2.to(x1.dtype), s.to(x1.dtype) + + +def generate_fused_rms_quant_data(M, N1, N2, dtype=torch.bfloat16): + x1 = torch.randn((M, N1), dtype=dtype, device="cuda") / 10 + x2 = torch.randn((M, N2), dtype=dtype, device="cuda") / 10 + w1 = torch.ones((N1,), dtype=torch.float32, device="cuda") + w2 = torch.ones((N2,), dtype=torch.float32, device="cuda") + res1 = torch.randn((M, N1), dtype=dtype, device="cuda") / 10 + return x1, w1, x2, w2, res1 + + +@pytest.mark.parametrize("M", [1, 32, 256]) +@pytest.mark.parametrize("N1, N2", [(128, 128), (128, 7168), (7168, 7168)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_rms_fp8_group_quant(M: int, N1: int, N2: int, dtype): + group_size = 128 + dtype_quant = aiter.dtypes.fp8 + x1, w1, x2, w2, res1 = generate_fused_rms_quant_data(M, N1, N2, dtype) + + (y1_q_torch, y1_s_torch), y1_torch, y2_torch, y1_res_torch = ( + run_torch_rms_fp8_group_quant( + x1, w1, 1e-6, x2, w2, 1e-6, res1, dtype_quant, group_size + ) + ) + + (y1_q_triton, y1_s_triton), y1_triton, y2_triton, y1_res_triton = ( + fused_rms_fp8_group_quant( + x1, + w1, + 1e-6, + inp2=x2, + inp2_weight=w2, + inp2_epsilon=1e-6, + group_size=group_size, + dtype_quant=dtype_quant, + res1=res1, + output_unquantized_inp1=True, + ) + ) + + torch.testing.assert_close(y1_torch, y1_triton, atol=0.1, rtol=0.1) + torch.testing.assert_close(y2_torch, y2_triton, 
atol=0.1, rtol=0.1) + torch.testing.assert_close(y1_res_torch, y1_res_triton, atol=0.1, rtol=0.1) + + y1_upcast_torch = upcast( + y1_q_torch, y1_s_torch, dtype=torch.float32, group_size=group_size + ) + y1_upcast_triton = upcast( + y1_q_triton, y1_s_triton, dtype=torch.float32, group_size=group_size + ) + torch.testing.assert_close(y1_upcast_torch, y1_upcast_triton, atol=0.1, rtol=0.1) + + +def run_torch_flatten_fp8_group_quant(x, dtype_quant, group_size): + y_q, y_s = per_token_fp8_group_quant( + x.reshape(x.shape[0], -1), dtype_quant, group_size + ) + return y_q, y_s + + +@pytest.mark.parametrize("M", [1, 32, 256]) +@pytest.mark.parametrize("N1, N2", [(16, 128)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_flatten_fp8_group_quant(M: int, N1: int, N2: int, dtype): + group_size = 128 + dtype_quant = aiter.dtypes.fp8 + x = torch.randn((N1, M, N2), dtype=dtype, device="cuda") / 10 + x = x.transpose(0, 1) + + y_q_torch, y_s_torch = run_torch_flatten_fp8_group_quant(x, dtype_quant, group_size) + + y_q_triton, y_s_triton = fused_flatten_fp8_group_quant( + x, + group_size=group_size, + dtype_quant=dtype_quant, + ) + + y_upcast_torch = upcast( + y_q_torch, y_s_torch, dtype=torch.float32, group_size=group_size + ) + y_upcast_triton = upcast( + y_q_triton, y_s_triton, dtype=torch.float32, group_size=group_size + ) + torch.testing.assert_close(y_upcast_torch, y_upcast_triton, atol=0.1, rtol=0.1) + + +def run_torch_reduce_act_mul_fp8_group_quant( + x, x2, activation, dtype, dtype_quant, group_size=128 +): + x = x.clone() + y2 = None + if x.dim() == 3: + x = x.sum(axis=0) + y2 = x2.sum(axis=0).to(dtype=dtype) + else: + assert x2 is None, "x2 must be None in x.dim() == 2 cases" + n = x.shape[1] // 2 + x, x_mul = x.split([n, n], dim=-1) + if activation == "silu": + x = F.silu(x) * x_mul + elif activation == "gelu": + x = F.gelu(x) * x_mul + + y_q, y_s = per_token_fp8_group_quant(x, dtype_quant, group_size) + + return (y_q, y_s), y2 + + 
+def generate_fused_reduce_act_mul_fp8_group_quant( + M: int, + N1: int, + dtype=torch.bfloat16, + SPK: int = 1, + N2: int = 1, +): + if SPK == 1: + x = torch.randn((M, N1 * 2), dtype=dtype).cuda() / 10 + else: + x = torch.randn((SPK, M, N1 * 2), dtype=torch.float32).cuda() / 10 + x2 = None + if SPK > 1: + x2 = torch.randn((SPK, M, N2), dtype=torch.float32).cuda() / 10 + + return x, x2 + + +@pytest.mark.parametrize("M", [1, 32, 256, 131072]) +@pytest.mark.parametrize("N1, N2", [(256, 256)]) +@pytest.mark.parametrize("SPK", [1, 4, 14]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("activation", ["silu", "gelu"]) +def test_fused_reduce_act_mul_fp8_group_quant( + M: int, N1: int, N2: int, SPK: int, dtype, activation +): + group_size = 128 + dtype_quant = aiter.dtypes.fp8 + + x, x2 = generate_fused_reduce_act_mul_fp8_group_quant( + M, N1, dtype=dtype, SPK=SPK, N2=N2 + ) + + (y_q_torch, y_s_torch), y2_torch = run_torch_reduce_act_mul_fp8_group_quant( + x, x2, activation, dtype, dtype_quant, group_size + ) + + (y_q_triton, y_s_triton), y2_triton = fused_reduce_act_mul_fp8_group_quant( + x, + activation=activation, + x2=x2, + group_size=group_size, + dtype_quant=dtype_quant, + dtype=dtype, + ) + + torch.testing.assert_close(y2_torch, y2_triton, atol=0.1, rtol=0.1) + + y_upcast_torch = upcast( + y_q_torch, y_s_torch, dtype=torch.float32, group_size=group_size + ) + y_upcast_triton = upcast( + y_q_triton, y_s_triton, dtype=torch.float32, group_size=group_size + ) + torch.testing.assert_close(y_upcast_torch, y_upcast_triton, atol=0.1, rtol=0.1) diff --git a/op_tests/triton_tests/test_fused_gemm_a8w8_blockscale_a16w16.py b/op_tests/triton_tests/test_fused_gemm_a8w8_blockscale_a16w16.py new file mode 100644 index 0000000000..13fdf74cc8 --- /dev/null +++ b/op_tests/triton_tests/test_fused_gemm_a8w8_blockscale_a16w16.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. 
All rights reserved. + +import torch +import triton +import pytest +from aiter.ops.triton.fused_gemm_a8w8_blockscale_a16w16 import ( + fused_gemm_a8w8_blockscale_a16w16, +) +from op_tests.triton_tests.test_gemm_a8w8_blockscale import ( + generate_gemm_a8w8_blockscale_inputs, +) +from op_tests.triton_tests.test_gemm_a8w8_blockscale import run_torch as run_torch_fp8 +from op_tests.triton_tests.test_gemm_a16w16 import generate_gemm_a16w16_inputs +import torch.nn.functional as F + + +block_shape = (128, 128) + + +def run_torch( + x_fp8, + w_fp8, + x_fp8_scale, + w_fp8_scale, + x_bf16, + w_bf16, + bias_fp8, + bias_bf16, + dtype=torch.bfloat16, +): + y_fp8 = run_torch_fp8(x_fp8, w_fp8, x_fp8_scale, w_fp8_scale, dtype) + if bias_fp8 is not None: + y_fp8 += bias_fp8 + y_bf16 = F.linear(x_bf16, w_bf16, bias=bias_bf16) + return y_fp8, y_bf16 + + +def run_triton( + x_fp8, + w_fp8, + x_fp8_scale, + w_fp8_scale, + x_bf16, + w_bf16, + bias_fp8, + bias_bf16, + dtype=torch.bfloat16, + y_fp8=None, + y_bf16=None, + skip_reduce=False, +): + return fused_gemm_a8w8_blockscale_a16w16( + x_fp8, + w_fp8, + x_fp8_scale, + w_fp8_scale, + x_bf16, + w_bf16, + bias_fp8=bias_fp8, + bias_bf16=bias_bf16, + dtype=dtype, + y_fp8=y_fp8, + y_bf16=y_bf16, + skip_reduce=skip_reduce, + ) + + +def get_x_vals(): + + x_vals = [(1, 1, 1, 128)] # minimal case + x_vals += [ + (m, n1, n2, k) + for k in [1024, 8192, 7168] + for n2 in [256, 512] + for n1 in [256, 512] + for m in [1, 8, 32, 64, 128, 8192] + ] + return x_vals + + +@pytest.mark.parametrize("M, N1, N2, K", get_x_vals()) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("output", [True, False]) +@pytest.mark.parametrize("skip_reduce", [True, False]) +def test_gemm(dtype, M, N1, N2, K, output, skip_reduce): + block_shape_n, block_shape_k = block_shape + + x_fp8, w_fp8, x_fp8_scale, w_fp8_scale, y_fp8 = ( + generate_gemm_a8w8_blockscale_inputs( + M, + N1, + K, + block_shape_n, + block_shape_k, + dtype=dtype, 
+ output=output, + ) + ) + x_bf16, w_bf16, bias_bf16, _, y_bf16 = generate_gemm_a16w16_inputs( + M, N2, K, dtype, output=output, bias=True + ) + bias_bf16 = torch.randn((N2,), dtype=bias_bf16.dtype, device=bias_bf16.device) + bias_fp8 = torch.randn((N1,), dtype=bias_bf16.dtype, device=bias_bf16.device) + y_torch_fp8, y_torch_bf16 = run_torch( + x_fp8, + w_fp8, + x_fp8_scale, + w_fp8_scale, + x_bf16, + w_bf16, + bias_fp8, + bias_bf16, + dtype, + ) + y_triton_fp8, y_triton_bf16 = run_triton( + x_fp8, + w_fp8, + x_fp8_scale, + w_fp8_scale, + x_bf16, + w_bf16, + bias_fp8, + bias_bf16, + dtype, + y_fp8, + y_bf16, + skip_reduce=skip_reduce, + ) + + if y_triton_fp8.dim() == 3: + y_triton_fp8 = y_triton_fp8.sum(axis=0).to(dtype=dtype) + y_triton_bf16 = y_triton_bf16.sum(axis=0).to(dtype=dtype) + + triton.testing.assert_close(y_torch_fp8, y_triton_fp8, atol=0.1, rtol=1e-1) + triton.testing.assert_close(y_torch_bf16, y_triton_bf16, atol=0.1, rtol=1e-1)