diff --git a/aiter/ops/triton/_triton_kernels/activation.py b/aiter/ops/triton/_triton_kernels/activation.py index 43c9324dc2..72b0006baa 100644 --- a/aiter/ops/triton/_triton_kernels/activation.py +++ b/aiter/ops/triton/_triton_kernels/activation.py @@ -1,4 +1,5 @@ from .quant import _mxfp4_quant_op +from .fused_fp8_quant import _fp8_quant_op import triton import triton.language as tl @@ -188,3 +189,87 @@ def _act_mul_and_dynamic_mxfp4_quant_kernel( bs_e8m0, mask=bs_mask, ) + + +@triton.heuristics( + { + "EVEN_N": lambda args: args["N"] % args["BLOCK_SIZE_N"] == 0, + } +) +@triton.jit +def _act_mul_and_dynamic_fp8_group_quant_kernel( + x_ptr, + x_fp8_ptr, + x_bs_ptr, + stride_x_m_in, + stride_x_n_in, + stride_x_fp8_m_in, + stride_x_fp8_n_in, + stride_bs_m_in, + stride_bs_n_in, + N, + ACTIVATION: tl.constexpr, + scaleN: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, + DTYPE_MAX: tl.constexpr, + DTYPE_MIN: tl.constexpr, + EVEN_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + # cast strides to int64, in case M*N > max int32 + stride_x_m = tl.cast(stride_x_m_in, tl.int64) + stride_x_n = tl.cast(stride_x_n_in, tl.int64) + stride_x_fp8_m = tl.cast(stride_x_fp8_m_in, tl.int64) + stride_x_fp8_n = tl.cast(stride_x_fp8_n_in, tl.int64) + stride_bs_m = tl.cast(stride_bs_m_in, tl.int64) + stride_bs_n = tl.cast(stride_bs_n_in, tl.int64) + NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE + + x_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + x_offs = pid_m * stride_x_m + x_offs_n * stride_x_n + + if EVEN_N: + a = tl.load(x_ptr + x_offs, cache_modifier=".cg").to(tl.float32) + b = tl.load(x_ptr + x_offs + stride_x_n * N, cache_modifier=".cg").to( + tl.float32 + ) + else: + x_mask = x_offs_n < N + a = tl.load(x_ptr + x_offs, mask=x_mask, cache_modifier=".cg").to(tl.float32) + # a and b can share the same mask + b = tl.load( + x_ptr + x_offs + stride_x_n * N, mask=x_mask, cache_modifier=".cg" + 
).to(tl.float32) + + x = _apply_activation_from_str(a, ACTIVATION) * b + + x_fp8, x_bs = _fp8_quant_op( + x, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN + ) + x_fp8 = tl.ravel(x_fp8) + x_bs = tl.ravel(x_bs) + + out_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_offs = pid_m * stride_x_fp8_m + out_offs_n * stride_x_fp8_n + + if EVEN_N: + tl.store(x_fp8_ptr + out_offs, x_fp8.to(x_fp8_ptr.dtype.element_ty)) + else: + out_mask = out_offs_n < N + tl.store( + x_fp8_ptr + out_offs, x_fp8.to(x_fp8_ptr.dtype.element_ty), mask=out_mask + ) + + bs_offs_n = pid_n * NUM_QUANT_BLOCKS + tl.arange(0, NUM_QUANT_BLOCKS) + bs_offs = pid_m * stride_bs_m + bs_offs_n * stride_bs_n + if EVEN_N: + tl.store(x_bs_ptr + bs_offs, x_bs.to(x_bs_ptr.dtype.element_ty)) + else: + bs_mask = bs_offs_n < scaleN + tl.store( + x_bs_ptr + bs_offs, + x_bs.to(x_bs_ptr.dtype.element_ty), + mask=bs_mask, + ) diff --git a/aiter/ops/triton/_triton_kernels/fused_add_rmsnorm_pad.py b/aiter/ops/triton/_triton_kernels/fused_add_rmsnorm_pad.py new file mode 100644 index 0000000000..f782b6e8e7 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/fused_add_rmsnorm_pad.py @@ -0,0 +1,81 @@ +import triton +import triton.language as tl + + +@triton.jit +def _rmsmorm_op(row, weight, n_cols, epsilon): + row_norm = row * row + row_norm = tl.sum(row_norm, axis=-1) + norm_factor = tl.math.rsqrt((row_norm / n_cols) + epsilon) + + rms_norm = row * norm_factor * weight + return rms_norm + + +@triton.jit +def _fused_add_rmsnorm_pad( + x_ptr, + res_ptr, + out_ptr, + res_out_ptr, + weight_ptr, + eps, + M, + N, + N_OUT, + x_stride_m, + x_stride_n, + res_stride_m, + res_stride_n, + out_stride_m, + out_stride_n, + res_out_stride_m, + res_out_stride_n, + HAS_RES: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + tl.assume(x_stride_m > 0) + tl.assume(x_stride_n > 0) + tl.assume(res_stride_m > 0) + tl.assume(res_stride_n > 0) + tl.assume(out_stride_m > 0) + tl.assume(out_stride_n > 0) + + pid_m = 
tl.program_id(0) + tl.assume(pid_m >= 0) + + n_offs = tl.arange(0, BLOCK_SIZE_N) + mask = n_offs < N + x = tl.load( + x_ptr + pid_m * x_stride_m + n_offs * x_stride_n, + mask=mask, + other=0.0, + cache_modifier=".cg", + ).to(tl.float32) + if HAS_RES: + res = tl.load( + res_ptr + pid_m * res_stride_m + n_offs * res_stride_n, + mask=mask, + other=0.0, + cache_modifier=".cg", + ).to(tl.float32) + x = x + res + + w = tl.load( + weight_ptr + n_offs, + mask=mask, + other=0.0, + ).to(tl.float32) + out = _rmsmorm_op(x, w, N, eps).to(out_ptr.dtype.element_ty) + + tl.store( + out_ptr + pid_m * out_stride_m + n_offs * out_stride_n, + out, + mask=(n_offs < N_OUT), + ) + if HAS_RES: + tl.store( + res_out_ptr + pid_m * res_out_stride_m + n_offs * res_out_stride_n, + x, + mask=mask, + ) diff --git a/aiter/ops/triton/_triton_kernels/fused_fp8_quant.py b/aiter/ops/triton/_triton_kernels/fused_fp8_quant.py new file mode 100644 index 0000000000..088b1ce415 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/fused_fp8_quant.py @@ -0,0 +1,345 @@ +import triton +import triton.language as tl + + +@triton.jit +def _rmsmorm_op(row, weight, n_cols, epsilon): + row_norm = row * row + row_norm = tl.sum(row_norm, axis=-1) + norm_factor = tl.math.rsqrt((row_norm / n_cols) + epsilon) + + rms_norm = row * norm_factor * weight + return rms_norm + + +@triton.jit +def _fp8_quant_op( + x, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, + DTYPE_MAX: tl.constexpr, + DTYPE_MIN: tl.constexpr, +): + NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE + x = x.reshape(BLOCK_SIZE_M, NUM_QUANT_BLOCKS, QUANT_BLOCK_SIZE) + m = tl.maximum(tl.max(tl.abs(x), axis=-1), 1e-10) + scale_out = m.to(tl.float32) / DTYPE_MAX + scale_recip = 1.0 / scale_out.reshape(BLOCK_SIZE_M, NUM_QUANT_BLOCKS, 1) + x = tl.clamp(x * scale_recip, DTYPE_MIN, DTYPE_MAX) + + return x, scale_out + + +@triton.jit +def _fused_rms_fp8_group_quant_kernel( + inp1_ptr, + weight1_ptr, + 
inp2_ptr, + weight2_ptr, + res1_ptr, + out1_fp8_ptr, + out1_bs_ptr, + out2_ptr, + out_res1_ptr, + out1_ptr, + eps1, + eps2, + n_rows, + inp1_n_cols, + inp2_n_cols, + inp1_row_stride, + inp2_row_stride, + inp1_col_stride, + inp2_col_stride, + res1_row_stride, + res1_col_stride, + out1_fp8_row_stride, + out1_fp8_col_stride, + out1_bs_row_stride, + out1_bs_col_stride, + out2_row_stride, + out2_col_stride, + out_res1_row_stride, + out_res1_col_stride, + out1_row_stride, + out1_col_stride, + BLOCK_SIZE_N: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, + DTYPE_MAX: tl.constexpr, + DTYPE_MIN: tl.constexpr, + HAVE_SECOND_INPUT: tl.constexpr, + FIRST_INPUT_RES: tl.constexpr, + FIRST_INPUT_OUT: tl.constexpr, +): + m_pid = tl.program_id(0) + n_offs = tl.arange(0, BLOCK_SIZE_N) + NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE + + mask1 = n_offs < inp1_n_cols + inp1 = tl.load( + inp1_ptr + m_pid * inp1_row_stride + n_offs * inp1_col_stride, + mask=mask1, + other=0.0, + cache_modifier=".cg", + ).to(tl.float32) + if FIRST_INPUT_RES: + res1 = tl.load( + res1_ptr + m_pid * res1_row_stride + n_offs * res1_col_stride, + mask=mask1, + other=0.0, + cache_modifier=".cg", + ).to(tl.float32) + inp1 = inp1 + res1 + + w1 = tl.load(weight1_ptr + n_offs, mask=mask1, other=0.0).to(tl.float32) + + norm1 = _rmsmorm_op(inp1, w1, inp1_n_cols, eps1) + + if FIRST_INPUT_OUT: + mask1 = n_offs < inp1_n_cols + tl.store( + out1_ptr + m_pid * out1_row_stride + n_offs * out1_col_stride, + norm1, + mask=mask1, + ) + + out1_fp8, out1_block_scales = _fp8_quant_op( + norm1, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN + ) + out1_fp8 = tl.ravel(out1_fp8) + out1_block_scales = tl.ravel(out1_block_scales) + + # store the results + tl.store( + out1_fp8_ptr + m_pid * out1_fp8_row_stride + n_offs * out1_fp8_col_stride, + out1_fp8.to(out1_fp8_ptr.dtype.element_ty), + mask=mask1, + ) + g_offs = tl.arange(0, NUM_QUANT_BLOCKS) + num_bs_cols = (inp1_n_cols + QUANT_BLOCK_SIZE - 1) // 
QUANT_BLOCK_SIZE + tl.store( + out1_bs_ptr + m_pid * out1_bs_row_stride + g_offs * out1_bs_col_stride, + out1_block_scales.to(out1_bs_ptr.dtype.element_ty), + mask=g_offs < num_bs_cols, + ) + if HAVE_SECOND_INPUT: + mask2 = n_offs < inp2_n_cols + inp2 = tl.load( + inp2_ptr + m_pid * inp2_row_stride + n_offs * inp2_col_stride, + mask=mask2, + other=0.0, + cache_modifier=".cg", + ).to(tl.float32) + w2 = tl.load(weight2_ptr + n_offs, mask=mask2, other=0.0).to(tl.float32) + norm2 = _rmsmorm_op(inp2, w2, inp2_n_cols, eps2) + tl.store( + out2_ptr + m_pid * out2_row_stride + n_offs * out2_col_stride, + norm2, + mask=mask2, + ) + + if FIRST_INPUT_RES: + inp1 = inp1.to(out_res1_ptr.dtype.element_ty) + tl.store( + out_res1_ptr + m_pid * out_res1_row_stride + n_offs * out_res1_col_stride, + inp1, + mask=mask1, + ) + + +@triton.jit +def _fused_flatten_fp8_group_quant_kernel( + x_ptr, + out_ptr, + out_scales_ptr, + x_stride_m, + x_stride_n1, + x_stride_n2, + out_stride_m, + out_stride_n, + out_scales_stride_m, + out_scales_stride_n, + N2, + BLOCK_SIZE_N2: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, + DTYPE_MAX: tl.constexpr, + DTYPE_MIN: tl.constexpr, +): + m = tl.program_id(0) + n1 = tl.program_id(1) + + NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N2 // QUANT_BLOCK_SIZE + + n2_offs = tl.arange(0, BLOCK_SIZE_N2) + x_offs = m * x_stride_m + n1 * x_stride_n1 + n2_offs * x_stride_n2 + x = tl.load(x_ptr + x_offs, mask=n2_offs < N2) + + out, out_block_scales = _fp8_quant_op( + x, 1, BLOCK_SIZE_N2, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN + ) + out = tl.ravel(out) + out_block_scales = tl.ravel(out_block_scales) + + tl.store( + out_ptr + m * out_stride_m + (n1 * BLOCK_SIZE_N2 + n2_offs) * out_stride_n, + out.to(out_ptr.dtype.element_ty), + mask=n2_offs < N2, + ) + block_scale_offs = tl.arange(0, NUM_QUANT_BLOCKS) + tl.store( + out_scales_ptr + + m * out_scales_stride_m + + (n1 * NUM_QUANT_BLOCKS + block_scale_offs) * out_scales_stride_n, + 
out_block_scales.to(out_scales_ptr.dtype.element_ty), + mask=block_scale_offs < tl.cdiv(N2, QUANT_BLOCK_SIZE), + ) + + +@triton.jit +def _fused_reduce_act_mul_fp8_group_quant( + x_ptr, + y_ptr, + y_scale_ptr, + x2_ptr, + y2_ptr, + M, + N1, + N2, + stride_x_spk, + stride_x_m, + stride_x_n, + stride_y_m, + stride_y_n, + stride_y_scale_m, + stride_y_scale_n, + stride_x2_spk, + stride_x2_m, + stride_x2_n, + stride_y2_m, + stride_y2_n, + # Meta-parameters + ACTIVATION: tl.constexpr, + BLOCK_SIZE_M2: tl.constexpr, + BLOCK_SIZE_N1: tl.constexpr, + BLOCK_SIZE_N2: tl.constexpr, + QUANT_BLOCK_SIZE: tl.constexpr, + DTYPE_MAX: tl.constexpr, + DTYPE_MIN: tl.constexpr, + X_HAS_SPLITK: tl.constexpr, + X_NUM_KSPLIT: tl.constexpr, + X_NUM_KSPLIT_POW2: tl.constexpr, + X_MASK: tl.constexpr, +): + + tl.assume(stride_x_spk > 0) + tl.assume(stride_x_m > 0) + tl.assume(stride_x_n > 0) + tl.assume(stride_y_m > 0) + tl.assume(stride_y_n > 0) + tl.assume(stride_y_scale_m > 0) + tl.assume(stride_y_scale_n > 0) + tl.assume(stride_x2_spk > 0) + tl.assume(stride_x2_m > 0) + tl.assume(stride_x2_n > 0) + tl.assume(stride_y2_m > 0) + tl.assume(stride_y2_n > 0) + + m_pid = tl.program_id(axis=0) + if X_HAS_SPLITK and m_pid >= M: + pid2 = m_pid - M + num_pid_n2 = tl.cdiv(N2, BLOCK_SIZE_N2) + pid_m2 = pid2 // num_pid_n2 + pid_n2 = pid2 % num_pid_n2 + offs_m2 = (pid_m2 * BLOCK_SIZE_M2 + tl.arange(0, BLOCK_SIZE_M2)) % M + offs_n2 = (pid_n2 * BLOCK_SIZE_N2 + tl.arange(0, BLOCK_SIZE_N2)) % N2 + offs_spk = tl.arange(0, X_NUM_KSPLIT_POW2) + x2_ptrs = ( + x2_ptr + + offs_spk[:, None, None] * stride_x2_spk + + offs_m2[None, :, None] * stride_x2_m + + offs_n2[None, None, :] * stride_x2_n + ) + if X_NUM_KSPLIT_POW2 == X_NUM_KSPLIT: + x2 = tl.load(x2_ptrs) + else: + x2 = tl.load( + x2_ptrs, mask=offs_spk[:, None, None] < X_NUM_KSPLIT, other=0.0 + ) + x2 = tl.sum(x2, axis=0) + + x2 = x2.to(y2_ptr.type.element_ty) + + y2_out_ptrs = ( + y2_ptr + (offs_m2[:, None] * stride_y2_m) + (offs_n2[None, :] * stride_y2_n) + 
) + + tl.store(y2_out_ptrs, x2) + return + + n_offs = tl.arange(0, BLOCK_SIZE_N1) + NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N1 // QUANT_BLOCK_SIZE + + mask = None + other = None + if X_HAS_SPLITK: + offs_spk = tl.arange(0, X_NUM_KSPLIT_POW2) + x_ptrs = ( + x_ptr + + offs_spk[:, None] * stride_x_spk + + m_pid * stride_x_m + + n_offs[None, :] * stride_x_n + ) + if X_MASK: + mask = (offs_spk[:, None] < X_NUM_KSPLIT) & (n_offs[None, :] < N1) + other = 0.0 + else: + mask = offs_spk[:, None] < X_NUM_KSPLIT + other = 0.0 + else: + x_ptrs = x_ptr + m_pid * stride_x_m + n_offs * stride_x_n + if X_MASK: + mask = n_offs < N1 + other = 0.0 + + x = tl.load( + x_ptrs, + mask=mask, + other=other, + cache_modifier=".cg", + ).to(tl.float32) + x_mul = tl.load( + x_ptrs + N1 * stride_x_n, + mask=mask, + other=other, + cache_modifier=".cg", + ).to(tl.float32) + + if X_HAS_SPLITK: + x = tl.sum(x, axis=0) + x_mul = tl.sum(x_mul, axis=0) + + x = ACTIVATION(x) * x_mul + + y, y_scale = _fp8_quant_op( + x, 1, BLOCK_SIZE_N1, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN + ) + y = tl.ravel(y) + y_scale = tl.ravel(y_scale) + + if X_MASK: + mask = n_offs < N1 + else: + mask = n_offs < N1 + tl.store( + y_ptr + m_pid * stride_y_m + n_offs * stride_y_n, + y.to(y_ptr.dtype.element_ty), + mask=mask, + ) + g_offs = tl.arange(0, NUM_QUANT_BLOCKS) + num_bs_cols = (N1 + QUANT_BLOCK_SIZE - 1) // QUANT_BLOCK_SIZE + tl.store( + y_scale_ptr + m_pid * stride_y_scale_m + g_offs * stride_y_scale_n, + y_scale.to(y_scale_ptr.dtype.element_ty), + mask=g_offs < num_bs_cols, + ) diff --git a/aiter/ops/triton/_triton_kernels/fused_gemm_a8w8_blockscale_a16w16.py b/aiter/ops/triton/_triton_kernels/fused_gemm_a8w8_blockscale_a16w16.py new file mode 100644 index 0000000000..6ac5b5ee46 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/fused_gemm_a8w8_blockscale_a16w16.py @@ -0,0 +1,452 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +from typing import Optional +import functools +import json +import os +import torch +import triton +import triton.language as tl +from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd +from ..utils._triton import arch_info +from ..utils.core import AITER_TRITON_CONFIGS_PATH + + +@triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0, + "GRID_MN_FP8": lambda args: triton.cdiv(args["M"], args["BLOCK_SIZE_M"]) + * triton.cdiv(args["N_fp8"], args["BLOCK_SIZE_N"]), + "GRID_MN_BF16": lambda args: triton.cdiv(args["M"], args["BLOCK_SIZE_M"]) + * triton.cdiv(args["N_bf16"], args["BLOCK_SIZE_N"]), + } +) +@triton.jit +def _fused_gemm_a8w8_blockscale_a16w16_kernel( + # Pointers to matrices + a_fp8_ptr, + b_fp8_ptr, + bias_fp8_ptr, + a_fp8_scale_ptr, + b_fp8_scale_ptr, + c_fp8_ptr, + a_bf16_ptr, + b_bf16_ptr, + bias_bf16_ptr, + c_bf16_ptr, + # Matrix dimensions + M, + N_fp8, + N_bf16, + K, + stride_a_fp8_m, + stride_a_fp8_k, + stride_b_fp8_k, + stride_b_fp8_n, + stride_a_fp8_scale_m, + stride_a_fp8_scale_k, + stride_b_fp8_scale_k, + stride_b_fp8_scale_n, + stride_c_fp8_k, + stride_c_fp8_m, + stride_c_fp8_n, + stride_a_bf16_m, + stride_a_bf16_k, + stride_b_bf16_k, + stride_b_bf16_n, + stride_c_bf16_k, + stride_c_bf16_m, + stride_c_bf16_n, + # Meta-parameters + GROUP_K: tl.constexpr, + GROUP_N: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_KSPLIT: tl.constexpr, + SPLITK_BLOCK_SIZE: tl.constexpr, + ADD_BIAS_FP8: tl.constexpr, + ADD_BIAS_BF16: tl.constexpr, + EVEN_K: tl.constexpr, + GRID_MN_FP8: tl.constexpr, + GRID_MN_BF16: tl.constexpr, + SKIP_REDUCE: tl.constexpr, + cache_modifier: tl.constexpr, +): + + tl.assume(stride_a_fp8_m > 0) + tl.assume(stride_a_fp8_k > 0) + tl.assume(stride_b_fp8_k > 0) + tl.assume(stride_b_fp8_n > 0) + tl.assume(stride_c_fp8_k > 0) + tl.assume(stride_c_fp8_m > 0) + tl.assume(stride_c_fp8_n > 0) + 
tl.assume(stride_a_fp8_scale_m > 0) + tl.assume(stride_a_fp8_scale_k > 0) + tl.assume(stride_b_fp8_scale_k > 0) + tl.assume(stride_b_fp8_scale_n > 0) + + tl.assume(stride_a_bf16_m > 0) + tl.assume(stride_a_bf16_k > 0) + tl.assume(stride_b_bf16_k > 0) + tl.assume(stride_b_bf16_n > 0) + tl.assume(stride_c_bf16_m > 0) + tl.assume(stride_c_bf16_n > 0) + + pid_unified = tl.program_id(axis=0) + pid_k = pid_unified % NUM_KSPLIT + pid = pid_unified // NUM_KSPLIT + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n_fp8 = tl.cdiv(N_fp8, BLOCK_SIZE_N) + num_pid_n_bf16 = tl.cdiv(N_bf16, BLOCK_SIZE_N) + num_pid_n = num_pid_n_fp8 + num_pid_n_bf16 + + if NUM_KSPLIT == 1: + GRID_MN: tl.constexpr = GRID_MN_FP8 + GRID_MN_BF16 + remap_xcd(pid, GRID_MN) + + pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M) + else: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(pid_k >= 0) + + if (pid_k * SPLITK_BLOCK_SIZE) < K: + + num_k_iter = tl.cdiv(SPLITK_BLOCK_SIZE, BLOCK_SIZE_K) + acc_dtype = tl.float32 if c_fp8_ptr.type.element_ty != tl.int8 else tl.int32 + + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_k_split = pid_k * SPLITK_BLOCK_SIZE + offs_k + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_ks_step = BLOCK_SIZE_K // GROUP_K + + if pid_n < num_pid_n_fp8: + offs_b_fp8_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N_fp8 + a_fp8_ptrs = a_fp8_ptr + ( + offs_am[:, None] * stride_a_fp8_m + + offs_k_split[None, :] * stride_a_fp8_k + ) + b_fp8_ptrs = b_fp8_ptr + ( + offs_k_split[:, None] * stride_b_fp8_k + + offs_b_fp8_n[None, :] * stride_b_fp8_n + ) + + offs_ks = (pid_k * SPLITK_BLOCK_SIZE) // GROUP_K + a_scale_ptrs = ( + a_fp8_scale_ptr + + offs_am * stride_a_fp8_scale_m + + offs_ks * stride_a_fp8_scale_k + ) + offs_bsn = offs_b_fp8_n // GROUP_N + b_scale_ptrs = ( + b_fp8_scale_ptr + + offs_ks * stride_b_fp8_scale_k + + offs_bsn * stride_b_fp8_scale_n + ) + + if 
ADD_BIAS_FP8: + if NUM_KSPLIT == 1 or (SKIP_REDUCE and pid_k == 0): + accumulator_fp8 = tl.load(bias_fp8_ptr + offs_b_fp8_n).to( + dtype=acc_dtype + ) + accumulator_fp8 = tl.broadcast_to( + accumulator_fp8[None, :], (BLOCK_SIZE_M, BLOCK_SIZE_N) + ) + else: + accumulator_fp8 = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype + ) + else: + accumulator_fp8 = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype + ) + + for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): + if EVEN_K: + a = tl.load(a_fp8_ptrs) + b = tl.load(b_fp8_ptrs, cache_modifier=cache_modifier) + else: + a = tl.load( + a_fp8_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_fp8_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0, + cache_modifier=cache_modifier, + ) + + a_scale = tl.load(a_scale_ptrs) + b_scale = tl.load(b_scale_ptrs) + + accumulator_fp8 += ( + tl.dot(a, b, input_precision="ieee") + * a_scale[:, None] + * b_scale[None, :] + ) + + a_fp8_ptrs += BLOCK_SIZE_K * stride_a_fp8_k + b_fp8_ptrs += BLOCK_SIZE_K * stride_b_fp8_k + a_scale_ptrs += offs_ks_step * stride_a_fp8_scale_k + b_scale_ptrs += offs_ks_step * stride_b_fp8_scale_k + + c_fp8 = accumulator_fp8.to(c_fp8_ptr.type.element_ty) + + offs_cm = pid_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( + tl.int64 + ) + offs_c_fp8_n = pid_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ).to(tl.int64) + c_fp8_ptrs = ( + c_fp8_ptr + + stride_c_fp8_m * offs_cm[:, None] + + stride_c_fp8_n * offs_c_fp8_n[None, :] + + pid_k * stride_c_fp8_k + ) + c_fp8_mask = (offs_cm[:, None] < M) & (offs_c_fp8_n[None, :] < N_fp8) + tl.store(c_fp8_ptrs, c_fp8, mask=c_fp8_mask) + else: + pid_n -= num_pid_n_fp8 + + offs_b_bf16_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N_bf16 + a_ptrs = a_bf16_ptr + ( + offs_am[:, None] * stride_a_bf16_m + + offs_k_split[None, :] * stride_a_bf16_k + ) + b_ptrs = b_bf16_ptr + ( + offs_k_split[:, None] * 
stride_b_bf16_k + + offs_b_bf16_n[None, :] * stride_b_bf16_n + ) + + if ADD_BIAS_BF16: + if NUM_KSPLIT == 1 or (SKIP_REDUCE and pid_k == 0): + accumulator_bf16 = tl.load(bias_bf16_ptr + offs_b_bf16_n).to( + dtype=acc_dtype + ) + accumulator_bf16 = tl.broadcast_to( + accumulator_bf16[None, :], (BLOCK_SIZE_M, BLOCK_SIZE_N) + ) + else: + accumulator_bf16 = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype + ) + else: + accumulator_bf16 = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype + ) + + for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs, cache_modifier=cache_modifier) + else: + a = tl.load( + a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0 + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0, + cache_modifier=cache_modifier, + ) + + accumulator_bf16 += tl.dot(a, b, input_precision="ieee") + + a_ptrs += BLOCK_SIZE_K * stride_a_bf16_k + b_ptrs += BLOCK_SIZE_K * stride_b_bf16_k + + c_bf16 = accumulator_bf16.to(c_bf16_ptr.type.element_ty) + + # Write back the block of the output matrix C with masks. 
+ offs_cm = pid_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( + tl.int64 + ) + offs_c_bf16_n = pid_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ).to(tl.int64) + c_bf16_ptrs = ( + c_bf16_ptr + + stride_c_bf16_m * offs_cm[:, None] + + stride_c_bf16_n * offs_c_bf16_n[None, :] + + pid_k * stride_c_bf16_k + ) + c_bf16_mask = (offs_cm[:, None] < M) & (offs_c_bf16_n[None, :] < N_bf16) + tl.store(c_bf16_ptrs, c_bf16, mask=c_bf16_mask) + + +@triton.jit +def _fused_gemm_a8w8_blockscale_a16w16_reduce_kernel( + bias_fp8_ptr, + c_fp8_in_ptr, + c_fp8_out_ptr, + bias_bf16_ptr, + c_bf16_in_ptr, + c_bf16_out_ptr, + M, + N_fp8, + N_bf16, + stride_c_fp8_in_k, + stride_c_fp8_in_m, + stride_c_fp8_in_n, + stride_c_fp8_out_m, + stride_c_fp8_out_n, + stride_c_bf16_in_k, + stride_c_bf16_in_m, + stride_c_bf16_in_n, + stride_c_bf16_out_m, + stride_c_bf16_out_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + ACTUAL_KSPLIT: tl.constexpr, + MAX_KSPLIT: tl.constexpr, + ADD_BIAS_FP8: tl.constexpr, + ADD_BIAS_BF16: tl.constexpr, +): + + tl.assume(stride_c_fp8_in_k > 0) + tl.assume(stride_c_fp8_in_m > 0) + tl.assume(stride_c_fp8_in_n > 0) + tl.assume(stride_c_fp8_out_m > 0) + tl.assume(stride_c_fp8_out_n > 0) + + tl.assume(stride_c_bf16_in_k > 0) + tl.assume(stride_c_bf16_in_m > 0) + tl.assume(stride_c_bf16_in_n > 0) + tl.assume(stride_c_bf16_out_m > 0) + tl.assume(stride_c_bf16_out_n > 0) + + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + + num_pid_n_fp8 = tl.cdiv(N_fp8, BLOCK_SIZE_N) + offs_k = tl.arange(0, MAX_KSPLIT) + acc_dtype = tl.float32 if c_fp8_in_ptr.type.element_ty != tl.int8 else tl.int32 + + if pid_n < num_pid_n_fp8: + offs_fp8_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N_fp8 + c_fp8_in_ptrs = ( + c_fp8_in_ptr + + (offs_k[:, None, None] * stride_c_fp8_in_k) + + (offs_m[None, :, None] * 
stride_c_fp8_in_m) + + (offs_fp8_n[None, None, :] * stride_c_fp8_in_n) + ) + + if ACTUAL_KSPLIT == MAX_KSPLIT: + c_fp8 = tl.load(c_fp8_in_ptrs) + else: + c_fp8 = tl.load( + c_fp8_in_ptrs, mask=offs_k[:, None, None] < ACTUAL_KSPLIT, other=0.0 + ) + c_fp8 = tl.sum(c_fp8, axis=0) + if ADD_BIAS_FP8: + bias_fp8 = tl.load(bias_fp8_ptr + offs_fp8_n).to(dtype=acc_dtype) + bias_fp8 = tl.broadcast_to(bias_fp8[None, :], (BLOCK_SIZE_M, BLOCK_SIZE_N)) + c_fp8 += bias_fp8 + + c_fp8 = c_fp8.to(c_fp8_out_ptr.type.element_ty) + + c_fp8_out_ptrs = ( + c_fp8_out_ptr + + (offs_m[:, None] * stride_c_fp8_out_m) + + (offs_fp8_n[None, :] * stride_c_fp8_out_n) + ) + + tl.store(c_fp8_out_ptrs, c_fp8) + else: + pid_n -= num_pid_n_fp8 + + offs_bf16_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N_bf16 + c_bf16_in_ptrs = ( + c_bf16_in_ptr + + (offs_k[:, None, None] * stride_c_bf16_in_k) + + (offs_m[None, :, None] * stride_c_bf16_in_m) + + (offs_bf16_n[None, None, :] * stride_c_bf16_in_n) + ) + + if ACTUAL_KSPLIT == MAX_KSPLIT: + c_bf16 = tl.load(c_bf16_in_ptrs) + else: + c_bf16 = tl.load( + c_bf16_in_ptrs, mask=offs_k[:, None, None] < ACTUAL_KSPLIT, other=0.0 + ) + c_bf16 = tl.sum(c_bf16, axis=0) + if ADD_BIAS_BF16: + bias_bf16 = tl.load(bias_bf16_ptr + offs_bf16_n).to(dtype=acc_dtype) + bias_bf16 = tl.broadcast_to( + bias_bf16[None, :], (BLOCK_SIZE_M, BLOCK_SIZE_N) + ) + c_bf16 += bias_bf16 + + c_bf16 = c_bf16.to(c_bf16_out_ptr.type.element_ty) + + c_bf16_out_ptrs = ( + c_bf16_out_ptr + + (offs_m[:, None] * stride_c_bf16_out_m) + + (offs_bf16_n[None, :] * stride_c_bf16_out_n) + ) + c_bf16_mask = (offs_m[:, None] < M) & (offs_bf16_n[None, :] < N_bf16) + tl.store(c_bf16_out_ptrs, c_bf16, mask=c_bf16_mask) + + +@functools.lru_cache(maxsize=1024) +def _get_config( + M: int, + N_fp8: int, + N_bf16: int, + K: int, +): + if not hasattr(_get_config, "_config_dict"): + dev = arch_info.get_device() + _get_config._config_dict = {} + fpath = 
f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16.json" + with open(fpath, "r") as file: + config = json.load(file) + _get_config._config_dict["default"] = config + + key = f"{N_fp8}_{N_bf16}_{K}" + if key not in _get_config._config_dict.keys(): + dev = arch_info.get_device() + fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16-N8={N_fp8}-N16={N_bf16}-K={K}.json" + if os.path.exists(fpath): + with open(fpath, "r") as file: + config = json.load(file) + _get_config._config_dict[key] = config + else: + key = "default" # fall back to default config + + if M < 16 and "small" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["small"] + elif M < 32 and "small_M16" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["small_M16"] + elif M <= 128: + BLK_M = triton.next_power_of_2(M) + if BLK_M == 32 and "medium_M32" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["medium_M32"] + elif BLK_M == 64 and "medium_M64" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["medium_M64"] + elif BLK_M == 128 and "medium_M128" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["medium_M128"] + elif M <= 256 and "large" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["large"] + else: + BLK_M = triton.next_power_of_2(M) + if f"xlarge_M{BLK_M}" in _get_config._config_dict[key]: + return _get_config._config_dict[key][f"xlarge_M{BLK_M}"] + elif "xlarge" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["xlarge"] + + return _get_config._config_dict[key]["any"] diff --git a/aiter/ops/triton/_triton_kernels/fused_kv_cache.py b/aiter/ops/triton/_triton_kernels/fused_kv_cache.py index 465531253c..cb05bbc4c7 100644 --- a/aiter/ops/triton/_triton_kernels/fused_kv_cache.py +++ b/aiter/ops/triton/_triton_kernels/fused_kv_cache.py @@ -1,9 +1,349 @@ import triton import triton.language as 
tl -from aiter.ops.triton._triton_kernels.rope import ( - _get_gptj_rotated_x_1D, - _get_neox_rotated_x_1D, -) +from aiter.ops.triton.rope import _get_gptj_rotated_x_1D, _get_neox_rotated_x_1D + + +@triton.jit +def _unit_cat( + x1_ptr, + x2_ptr, + x_out_ptr, + b_in, + b_out, + h, + d1_offs, + d2_offs, + x1_stride_b, + x1_stride_h, + x1_stride_d, + x2_stride_b, + x2_stride_h, + x2_stride_d, + x_out_stride_b, + x_out_stride_h, + x_out_stride_d, + k_scale, + BLOCK_D1: tl.constexpr, +): + x1_offs = b_in * x1_stride_b + h * x1_stride_h + d1_offs * x1_stride_d + x2_offs = b_in * x2_stride_b + h * x2_stride_h + d2_offs * x2_stride_d + x_out_offs = b_out * x_out_stride_b + h * x_out_stride_h + + x1 = tl.load(x1_ptr + x1_offs) + x2 = tl.load(x2_ptr + x2_offs) + + x1 = (x1 / k_scale).to(x_out_ptr.dtype.element_ty) + x2 = (x2 / k_scale).to(x_out_ptr.dtype.element_ty) + tl.store(x_out_ptr + x_out_offs + d1_offs * x_out_stride_d, x1) + tl.store(x_out_ptr + x_out_offs + (d2_offs + BLOCK_D1) * x_out_stride_d, x2) + + +@triton.jit +def _unit_rope_cat( + x_nope_ptr, + x_pe_ptr, + cos, + sin, + x_out_ptr, + b_in, + b_out, + h, + d_nope_offs, + d_pe_offs, + x_nope_stride_b, + x_nope_stride_h, + x_nope_stride_d, + x_pe_stride_b, + x_pe_stride_h, + x_pe_stride_d, + x_out_stride_b, + x_out_stride_h, + x_out_stride_d, + k_scale, + IS_NEOX: tl.constexpr, + BLOCK_D_nope: tl.constexpr, + BLOCK_D_pe: tl.constexpr, + BLOCK_D_HALF_pe: tl.constexpr, +): + x_nope_offs = ( + b_in * x_nope_stride_b + h * x_nope_stride_h + d_nope_offs * x_nope_stride_d + ) + x_pe_offs = b_in * x_pe_stride_b + h * x_pe_stride_h + d_pe_offs * x_pe_stride_d + x_out_offs = b_out * x_out_stride_b + h * x_out_stride_h + + x_nope = tl.load(x_nope_ptr + x_nope_offs) + x_pe = tl.load(x_pe_ptr + x_pe_offs) + + if IS_NEOX: + x_rotated_mask = d_pe_offs < BLOCK_D_HALF_pe + x_pe_rotated = _get_neox_rotated_x_1D( + x_pe, x_rotated_mask, BLOCK_D_pe, BLOCK_D_HALF_pe + ) + else: + x_rotated_mask = d_pe_offs % 2 == 0 + x_pe_rotated = 
_get_gptj_rotated_x_1D( + x_pe, x_rotated_mask, BLOCK_D_pe, BLOCK_D_HALF_pe + ) + + x_pe = x_pe * cos + x_pe_rotated * sin + x_pe = x_pe / k_scale + x_nope = x_nope / k_scale + x_nope = x_nope.to(x_out_ptr.dtype.element_ty) + x_pe = x_pe.to(x_out_ptr.dtype.element_ty) + + tl.store(x_out_ptr + x_out_offs + d_nope_offs * x_out_stride_d, x_nope) + tl.store(x_out_ptr + x_out_offs + (d_pe_offs + BLOCK_D_nope) * x_out_stride_d, x_pe) + + +@triton.jit +def _fused_qk_rope_cat_and_cache_mla_kernel( + q_nope_ptr, + q_pe_ptr, + k_nope_ptr, + k_pe_ptr, + pos_ptr, + cos_ptr, + sin_ptr, + q_out_ptr, + decode_q_pe_out_ptr, + k_pe_out_ptr, + q_nope_zeros_out_ptr, + kv_cache_ptr, + slot_mapping_ptr, + B, + B_slot, + num_decode_toks_for_zeros, + q_nope_stride_b, + q_nope_stride_h, + q_nope_stride_d, + q_pe_stride_b, + q_pe_stride_h, + q_pe_stride_d, + k_nope_stride_b, + k_nope_stride_h, + k_nope_stride_d, + k_pe_stride_b, + k_pe_stride_h, + k_pe_stride_d, + pos_stride_b, + cos_stride_b, + cos_stride_d, + q_out_stride_b, + q_out_stride_h, + q_out_stride_d, + decode_q_pe_out_stride_b, + decode_q_pe_out_stride_h, + decode_q_pe_out_stride_d, + k_pe_out_stride_b, + k_pe_out_stride_h, + k_pe_out_stride_d, + q_nope_zeros_out_stride_b, + q_nope_zeros_out_stride_h, + q_nope_zeros_out_stride_d, + kv_cache_stride_b, + kv_cache_stride_h, + kv_cache_stride_d, + k_scale_ptr, + QH_PER_KH: tl.constexpr, + QH: tl.constexpr, + KH: tl.constexpr, + REUSE_FREQS_FRONT_PART: tl.constexpr, + IS_NEOX: tl.constexpr, + BLOCK_D_nope: tl.constexpr, + BLOCK_DK_nope: tl.constexpr, + BLOCK_D_pe: tl.constexpr, + BLOCK_D_HALF_pe: tl.constexpr, + OUTPUT_Q_NOPE_ZEROS: tl.constexpr = False, + HAVE_K_SCALE: tl.constexpr = False, +): + pid = tl.program_id(0) + + d_nope_offs = tl.arange(0, BLOCK_D_nope).to(tl.int64) + dk_nope_offs = tl.arange(0, BLOCK_DK_nope).to(tl.int64) + d_pe_offs = tl.arange(0, BLOCK_D_pe).to(tl.int64) + + if pid < B * QH: + pid_b = pid // QH + pid_hq = pid % QH + if REUSE_FREQS_FRONT_PART: + if 
IS_NEOX: + d_cos_offs = d_pe_offs + d_cos_offs = tl.where( + (d_cos_offs >= BLOCK_D_HALF_pe) & (d_cos_offs < BLOCK_D_pe), + d_cos_offs - BLOCK_D_HALF_pe, + d_cos_offs, + ).to(d_cos_offs.dtype) + # d_cos_mask = d_cos_offs < BLOCK_D_pe + else: + d_cos_offs = d_pe_offs // 2 + # d_cos_mask = d_cos_offs < BLOCK_D_HALF_pe + else: + d_cos_offs = d_pe_offs + # d_cos_mask = d_cos_offs < BLOCK_D_pe + + pos = tl.load(pos_ptr + pid_b * pos_stride_b) + cos_offs = pos * cos_stride_b + d_cos_offs * cos_stride_d + cos = tl.load(cos_ptr + cos_offs) + sin = tl.load(sin_ptr + cos_offs) + + q_nope_ptrs = ( + q_nope_ptr + + pid_b * q_nope_stride_b + + pid_hq * q_nope_stride_h + + d_nope_offs * q_nope_stride_d + ) + q_pe_ptrs = ( + q_pe_ptr + + pid_b * q_pe_stride_b + + pid_hq * q_pe_stride_h + + d_pe_offs * q_pe_stride_d + ) + q_out_ptrs = q_out_ptr + pid_b * q_out_stride_b + pid_hq * q_out_stride_h + q_nope = tl.load(q_nope_ptrs) + q_pe = _unit_rope( + q_pe_ptrs, + cos, + sin, + d_pe_offs, + IS_NEOX, + BLOCK_D_pe, + BLOCK_D_HALF_pe, + ) + tl.store( + q_out_ptrs + d_nope_offs * q_out_stride_d, + q_nope.to(q_out_ptr.dtype.element_ty), + ) + tl.store( + q_out_ptrs + (d_pe_offs + BLOCK_D_nope) * q_out_stride_d, + q_pe.to(q_out_ptr.dtype.element_ty), + ) + + if pid < num_decode_toks_for_zeros * QH: + decode_q_pe_out_ptrs = ( + decode_q_pe_out_ptr + + pid_b * decode_q_pe_out_stride_b + + pid_hq * decode_q_pe_out_stride_h + ) + tl.store( + decode_q_pe_out_ptrs + d_pe_offs * decode_q_pe_out_stride_d, + q_pe.to(decode_q_pe_out_ptr.dtype.element_ty), + ) + + if OUTPUT_Q_NOPE_ZEROS and pid < num_decode_toks_for_zeros * QH: + z = tl.zeros((BLOCK_DK_nope,), dtype=q_nope_zeros_out_ptr.dtype.element_ty) + tl.store( + q_nope_zeros_out_ptr + + pid_b * q_nope_zeros_out_stride_b + + pid_hq * q_nope_zeros_out_stride_h + + dk_nope_offs * q_nope_zeros_out_stride_d, + z, + ) + + if pid_hq % QH_PER_KH == 0: + pid_slot = tl.load(slot_mapping_ptr + pid_b).to(tl.int64) + if pid_slot >= 0: + if HAVE_K_SCALE: + 
k_scale = tl.load(k_scale_ptr) + else: + k_scale = 1 + + pid_hk = pid_hq // QH_PER_KH + k_nope_ptrs = ( + k_nope_ptr + + pid_b * k_nope_stride_b + + pid_hk * k_nope_stride_h + + dk_nope_offs * k_nope_stride_d + ) + k_pe_ptrs = ( + k_pe_ptr + + pid_b * k_pe_stride_b + + pid_hk * k_pe_stride_h + + d_pe_offs * k_pe_stride_d + ) + k_pe_out_ptrs = ( + k_pe_out_ptr + + pid_b * k_pe_out_stride_b + + pid_hk * k_pe_out_stride_h + + d_pe_offs * k_pe_out_stride_d + ) + kv_cache_ptrs = ( + kv_cache_ptr + + pid_slot * kv_cache_stride_b + + pid_hk * kv_cache_stride_h + ) + k_nope = tl.load(k_nope_ptrs) + k_pe = _unit_rope( + k_pe_ptrs, + cos, + sin, + d_pe_offs, + IS_NEOX, + BLOCK_D_pe, + BLOCK_D_HALF_pe, + ) + tl.store(k_pe_out_ptrs, k_pe.to(k_pe_out_ptr.dtype.element_ty)) + k_scale_rcprl = (1 / k_scale).to(tl.float32) + k_nope = (k_nope.to(tl.float32) * k_scale_rcprl).to( + kv_cache_ptr.dtype.element_ty + ) + k_pe = (k_pe.to(tl.float32) * k_scale_rcprl).to( + kv_cache_ptr.dtype.element_ty + ) + tl.store(kv_cache_ptrs + dk_nope_offs * kv_cache_stride_d, k_nope) + tl.store( + kv_cache_ptrs + (d_pe_offs + BLOCK_DK_nope) * kv_cache_stride_d, + k_pe, + ) + else: + pid = pid - B * QH + B * KH + if pid < B_slot * KH: + pid_b = pid // KH + pid_hk = pid % KH + pid_slot = tl.load(slot_mapping_ptr + pid_b).to(tl.int64) + if pid_slot >= 0: + if HAVE_K_SCALE: + k_scale = tl.load(k_scale_ptr) + else: + k_scale = 1 + + k_nope_ptrs = ( + k_nope_ptr + + pid_b * k_nope_stride_b + + pid_hk * k_nope_stride_h + + dk_nope_offs * k_nope_stride_d + ) + k_pe_ptrs = ( + k_pe_ptr + + pid_b * k_pe_stride_b + + pid_hk * k_pe_stride_h + + d_pe_offs * k_pe_stride_d + ) + k_pe_out_ptrs = ( + k_pe_out_ptr + + pid_b * k_pe_out_stride_b + + pid_hk * k_pe_out_stride_h + + d_pe_offs * k_pe_out_stride_d + ) + kv_cache_ptrs = ( + kv_cache_ptr + + pid_slot * kv_cache_stride_b + + pid_hk * kv_cache_stride_h + ) + k_nope = tl.load(k_nope_ptrs) + k_pe = tl.load(k_pe_ptrs) + tl.store(k_pe_out_ptrs, 
k_pe.to(k_pe_out_ptr.dtype.element_ty)) + k_scale_rcprl = (1 / k_scale).to(tl.float32) + k_nope = (k_nope.to(tl.float32) * k_scale_rcprl).to( + kv_cache_ptr.dtype.element_ty + ) + k_pe = (k_pe.to(tl.float32) * k_scale_rcprl).to( + kv_cache_ptr.dtype.element_ty + ) + tl.store(kv_cache_ptrs + dk_nope_offs * kv_cache_stride_d, k_nope) + tl.store( + kv_cache_ptrs + (d_pe_offs + BLOCK_DK_nope) * kv_cache_stride_d, + k_pe, + ) @triton.jit @@ -16,24 +356,337 @@ def _unit_rope( BLOCK_D_pe: tl.constexpr, BLOCK_D_HALF_pe: tl.constexpr, ): - x = tl.load(x_ptrs).to(tl.float64) + x_pe = tl.load(x_ptrs) if IS_NEOX: x_rotated_mask = d_pe_offs < BLOCK_D_HALF_pe x_pe_rotated = _get_neox_rotated_x_1D( - x, x_rotated_mask, BLOCK_D_pe, BLOCK_D_HALF_pe + x_pe, x_rotated_mask, BLOCK_D_pe, BLOCK_D_HALF_pe ) else: x_rotated_mask = d_pe_offs % 2 == 0 x_pe_rotated = _get_gptj_rotated_x_1D( - x, x_rotated_mask, BLOCK_D_pe, BLOCK_D_HALF_pe + x_pe, x_rotated_mask, BLOCK_D_pe, BLOCK_D_HALF_pe ) - x_pe = x * cos + x_pe_rotated * sin + x_pe = x_pe * cos + x_pe_rotated * sin return x_pe +@triton.jit +def _fused_qk_rope_reshape_and_cache_kernel( + q_ptr, + k_ptr, + v_ptr, + pos_ptr, + cos_ptr, + sin_ptr, + offs_ptr, + key_cache_ptr, + value_cache_ptr, + slot_mapping_ptr, + q_out_ptr, + k_out_ptr, + zeros_out_ptr, + T, + T_slot, + q_stride_t, + q_stride_h, + q_stride_d, + k_stride_t, + k_stride_h, + k_stride_d, + v_stride_t, + v_stride_h, + v_stride_d, + cos_stride_t, + cos_stride_d, + q_out_stride_t, + q_out_stride_h, + q_out_stride_d, + k_out_stride_t, + k_out_stride_h, + k_out_stride_d, + key_cache_stride_t, + key_cache_stride_h, + key_cache_stride_d, + key_cache_stride_b, + key_cache_stride_x, + value_cache_stride_t, + value_cache_stride_h, + value_cache_stride_d, + value_cache_stride_b, + zeros_out_stride_t, + zeros_out_stride_h, + zeros_out_stride_d, + k_scale_ptr, + v_scale_ptr, + QH_PER_KH: tl.constexpr, + QH: tl.constexpr, + KH: tl.constexpr, + REUSE_FREQS_FRONT_PART: tl.constexpr, + 
IS_NEOX: tl.constexpr, + BLOCK_D_pe: tl.constexpr, + BLOCK_D_HALF_pe: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + X_SIZE: tl.constexpr, + FLASH_LAYOUT: tl.constexpr, + HAVE_POS: tl.constexpr = False, + HAVE_K_SCALE: tl.constexpr = False, + HAVE_V_SCALE: tl.constexpr = False, + HAVE_ZEROS: tl.constexpr = False, +): + + tl.assume(q_stride_t >= 0) + tl.assume(q_stride_h >= 0) + tl.assume(q_stride_d >= 0) + tl.assume(k_stride_t >= 0) + tl.assume(k_stride_h >= 0) + tl.assume(k_stride_d >= 0) + tl.assume(v_stride_t >= 0) + tl.assume(v_stride_h >= 0) + tl.assume(v_stride_d >= 0) + tl.assume(cos_stride_t >= 0) + tl.assume(cos_stride_d >= 0) + tl.assume(q_out_stride_t >= 0) + tl.assume(q_out_stride_h >= 0) + tl.assume(q_out_stride_d >= 0) + tl.assume(k_out_stride_t >= 0) + tl.assume(k_out_stride_h >= 0) + tl.assume(k_out_stride_d >= 0) + tl.assume(key_cache_stride_t >= 0) + tl.assume(key_cache_stride_h >= 0) + tl.assume(key_cache_stride_d >= 0) + tl.assume(key_cache_stride_b >= 0) + tl.assume(key_cache_stride_x >= 0) + tl.assume(value_cache_stride_t >= 0) + tl.assume(value_cache_stride_h >= 0) + tl.assume(value_cache_stride_d >= 0) + tl.assume(value_cache_stride_b >= 0) + tl.assume(zeros_out_stride_t >= 0) + tl.assume(zeros_out_stride_h >= 0) + tl.assume(zeros_out_stride_d >= 0) + + pid = tl.program_id(0) + tl.assume(pid >= 0) + + d_pe_offs = tl.arange(0, BLOCK_D_pe).to(tl.int64) + + if pid < T * QH: + pid_t = pid // QH + pid_hq = pid % QH + if REUSE_FREQS_FRONT_PART: + if IS_NEOX: + d_cos_offs = d_pe_offs + d_cos_offs = tl.where( + (d_cos_offs >= BLOCK_D_HALF_pe) & (d_cos_offs < BLOCK_D_pe), + d_cos_offs - BLOCK_D_HALF_pe, + d_cos_offs, + ).to(d_cos_offs.dtype) + # d_cos_mask = d_cos_offs < BLOCK_D_pe + else: + d_cos_offs = d_pe_offs // 2 + # d_cos_mask = d_cos_offs < BLOCK_D_HALF_pe + else: + d_cos_offs = d_pe_offs + # d_cos_mask = d_cos_offs < BLOCK_D_pe + + pos = tl.load(pos_ptr + pid_t) + if HAVE_POS: + offset = tl.load(offs_ptr + pid_t) + pos = pos + offset + cos_offs 
= pos * cos_stride_t + d_cos_offs * cos_stride_d + cos = tl.load(cos_ptr + cos_offs) + sin = tl.load(sin_ptr + cos_offs) + + q_ptrs = ( + q_ptr + pid_t * q_stride_t + pid_hq * q_stride_h + d_pe_offs * q_stride_d + ) + q_pe = _unit_rope( + q_ptrs, + cos, + sin, + d_pe_offs, + IS_NEOX, + BLOCK_D_pe, + BLOCK_D_HALF_pe, + ) + q_out_ptrs = ( + q_out_ptr + + pid_t * q_out_stride_t + + pid_hq * q_out_stride_h + + d_pe_offs * q_out_stride_d + ) + tl.store(q_out_ptrs, q_pe.to(q_out_ptr.dtype.element_ty)) + + if HAVE_ZEROS: + z = tl.zeros((BLOCK_D_pe,), dtype=zeros_out_ptr.dtype.element_ty) + zeros_out_ptrs = ( + zeros_out_ptr + + pid_t * zeros_out_stride_t + + pid_hq * zeros_out_stride_h + + d_pe_offs * zeros_out_stride_d + ) + tl.store(zeros_out_ptrs, z) + + if pid_hq % QH_PER_KH == 0: + pid_slot = tl.load(slot_mapping_ptr + pid_t).to(tl.int64) + if pid_slot >= 0: + pid_t_slot = pid_slot // BLOCK_SIZE + pid_b = pid_slot % BLOCK_SIZE + pid_hk = pid_hq // QH_PER_KH + if HAVE_K_SCALE: + k_scale = tl.load(k_scale_ptr) + else: + k_scale = 1 + k_ptrs = ( + k_ptr + + pid_t * k_stride_t + + pid_hk * k_stride_h + + d_pe_offs * k_stride_d + ) + k_pe = _unit_rope( + k_ptrs, + cos, + sin, + d_pe_offs, + IS_NEOX, + BLOCK_D_pe, + BLOCK_D_HALF_pe, + ) + + k_out_ptrs = ( + k_out_ptr + + pid_t * k_out_stride_t + + pid_hk * k_out_stride_h + + d_pe_offs * k_out_stride_d + ) + tl.store(k_out_ptrs, k_pe.to(k_out_ptr.dtype.element_ty)) + + k_scale_rcprl = 1 / k_scale + k_pe = k_pe * k_scale_rcprl + + if FLASH_LAYOUT: + k_out_ptrs = ( + key_cache_ptr + + pid_t_slot * key_cache_stride_t + + pid_b * key_cache_stride_b + + pid_hk * key_cache_stride_h + + d_pe_offs * key_cache_stride_d + ) + else: + k_pe = tl.reshape(k_pe, (BLOCK_D_pe // X_SIZE, X_SIZE)) + dx_offs = tl.arange(0, BLOCK_D_pe // X_SIZE).to(tl.int64) + x_offs = tl.arange(0, X_SIZE).to(tl.int64) + k_out_ptrs = ( + key_cache_ptr + + pid_t_slot * key_cache_stride_t + + pid_hk * key_cache_stride_h + + dx_offs[:, None] * key_cache_stride_d + 
+ pid_b * key_cache_stride_b + + x_offs[None, :] * key_cache_stride_x + ) + + tl.store(k_out_ptrs, k_pe.to(key_cache_ptr.dtype.element_ty)) + + v_ptrs = ( + v_ptr + + pid_t * v_stride_t + + pid_hk * v_stride_h + + d_pe_offs * v_stride_d + ) + if HAVE_V_SCALE: + v_scale = tl.load(v_scale_ptr) + else: + v_scale = 1 + v_scale_rcprl = 1 / v_scale + v = tl.load(v_ptrs) * v_scale_rcprl + v_out_ptrs = ( + value_cache_ptr + + pid_t_slot * value_cache_stride_t + + pid_hk * value_cache_stride_h + + d_pe_offs.to(tl.int64) * value_cache_stride_d + + pid_b * value_cache_stride_b + ) + tl.store(v_out_ptrs, v.to(value_cache_ptr.dtype.element_ty)) + else: + pid = pid - T * QH + T * KH + if pid < T_slot * KH: + pid_t = pid // KH + pid_hk = pid % KH + pid_slot = tl.load(slot_mapping_ptr + pid_t).to(tl.int64) + if pid_slot >= 0: + pid_t_slot = pid_slot // BLOCK_SIZE + pid_b = pid_slot % BLOCK_SIZE + if HAVE_K_SCALE: + k_scale = tl.load(k_scale_ptr) + else: + k_scale = 1 + k_ptrs = ( + k_ptr + + pid_t * k_stride_t + + pid_hk * k_stride_h + + d_pe_offs * k_stride_d + ) + + k_pe = tl.load(k_ptrs) + + k_out_ptrs = ( + k_out_ptr + + pid_t * k_out_stride_t + + pid_hk * k_out_stride_h + + d_pe_offs * k_out_stride_d + ) + tl.store(k_out_ptrs, k_pe.to(k_out_ptr.dtype.element_ty)) + + k_scale_rcprl = 1 / k_scale + k_pe = k_pe * k_scale_rcprl + + if FLASH_LAYOUT: + k_out_ptrs = ( + key_cache_ptr + + pid_t_slot * key_cache_stride_t + + d_pe_offs * key_cache_stride_d + + pid_b * key_cache_stride_b + + pid_hk * key_cache_stride_h + ) + else: + k_pe = tl.reshape(k_pe, (BLOCK_D_pe // X_SIZE, X_SIZE)) + dx_offs = tl.arange(0, BLOCK_D_pe // X_SIZE).to(tl.int64) + x_offs = tl.arange(0, X_SIZE).to(tl.int64) + k_out_ptrs = ( + key_cache_ptr + + pid_t_slot * key_cache_stride_t + + pid_hk * key_cache_stride_h + + dx_offs[:, None] * key_cache_stride_d + + pid_b * key_cache_stride_b + + x_offs[None, :] * key_cache_stride_x + ) + tl.store(k_out_ptrs, k_pe.to(key_cache_ptr.dtype.element_ty)) + + v_ptrs = ( + 
v_ptr + + pid_t * v_stride_t + + pid_hk * v_stride_h + + d_pe_offs * v_stride_d + ) + if HAVE_V_SCALE: + v_scale = tl.load(v_scale_ptr) + else: + v_scale = 1 + v_scale_rcprl = 1 / v_scale + v = tl.load(v_ptrs) * v_scale_rcprl + v_out_ptrs = ( + value_cache_ptr + + pid_t_slot * value_cache_stride_t + + pid_hk * value_cache_stride_h + + d_pe_offs * value_cache_stride_d + + pid_b * value_cache_stride_b + ) + tl.store(v_out_ptrs, v.to(value_cache_ptr.dtype.element_ty)) + + @triton.jit def _fused_qk_rope_cosine_cache_llama_kernel( q_ptr, @@ -90,7 +743,7 @@ def _fused_qk_rope_cosine_cache_llama_kernel( ): pid = tl.program_id(0) - d_pe_offs = tl.arange(0, BLOCK_D_pe) + d_pe_offs = tl.arange(0, BLOCK_D_pe).to(tl.int64) if pid < T * QH: pid_t = pid // QH @@ -177,8 +830,8 @@ def _fused_qk_rope_cosine_cache_llama_kernel( ) else: k_pe = tl.reshape(k_pe, (BLOCK_D_pe // X_SIZE, X_SIZE)) - dx_offs = tl.arange(0, BLOCK_D_pe // X_SIZE) - x_offs = tl.arange(0, X_SIZE) + dx_offs = tl.arange(0, BLOCK_D_pe // X_SIZE).to(tl.int64) + x_offs = tl.arange(0, X_SIZE).to(tl.int64) k_out_ptrs = ( key_cache_ptr + pid_t_slot * key_cache_stride_t @@ -245,8 +898,8 @@ def _fused_qk_rope_cosine_cache_llama_kernel( ) else: k_pe = tl.reshape(k_pe, (BLOCK_D_pe // X_SIZE, X_SIZE)) - dx_offs = tl.arange(0, BLOCK_D_pe // X_SIZE) - x_offs = tl.arange(0, X_SIZE) + dx_offs = tl.arange(0, BLOCK_D_pe // X_SIZE).to(tl.int64) + x_offs = tl.arange(0, X_SIZE).to(tl.int64) k_out_ptrs = ( key_cache_ptr + pid_t_slot * key_cache_stride_t diff --git a/aiter/ops/triton/_triton_kernels/fused_qkv_split_qk_rope.py b/aiter/ops/triton/_triton_kernels/fused_qkv_split_qk_rope.py new file mode 100644 index 0000000000..cf2f1bb9ef --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/fused_qkv_split_qk_rope.py @@ -0,0 +1,187 @@ +import triton +import triton.language as tl +from aiter.ops.triton.rope import _get_gptj_rotated_x, _get_neox_rotated_x + + +@triton.jit +def _fused_qkv_split_qk_rope_kernel( + qkv_ptr, + cos_ptr, + 
sin_ptr, + pos_ptr, + off_ptr, + q_ptr, + k_ptr, + v_ptr, + T, + stride_qkv_t, + stride_qkv_d, + stride_cos_t, + stride_cos_d, + stride_pos_t, + stride_q_t, + stride_q_h, + stride_q_d, + stride_kv_t, + stride_kv_h, + stride_kv_d, + HAVE_NOPE: tl.constexpr, + NOPE_FIRST: tl.constexpr, + REUSE_FREQS_FRONT_PART: tl.constexpr, + IS_NEOX: tl.constexpr, + HAVE_POS: tl.constexpr, + HAVE_OFFS: tl.constexpr, + QH: tl.constexpr, + KVH: tl.constexpr, + BLOCK_T: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_D_HALF: tl.constexpr, +): + tl.assume(stride_qkv_t > 0) + tl.assume(stride_qkv_d > 0) + tl.assume(stride_cos_t > 0) + tl.assume(stride_cos_d > 0) + tl.assume(stride_pos_t > 0) + tl.assume(stride_q_t > 0) + tl.assume(stride_q_h > 0) + tl.assume(stride_q_d > 0) + tl.assume(stride_kv_t > 0) + tl.assume(stride_kv_h > 0) + tl.assume(stride_kv_d > 0) + + pid_t = tl.program_id(0) + hq = tl.program_id(1) + + tl.assume(pid_t >= 0) + tl.assume(hq >= 0) + + t_offs = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) + d_offs = tl.arange(0, BLOCK_D) + t_mask = t_offs < T + + if HAVE_POS: + pos_offs = t_offs * stride_pos_t + pos = tl.load(pos_ptr + pos_offs, mask=t_mask) + if HAVE_OFFS: + offset = tl.load(off_ptr + pos_offs, mask=t_mask) + t_cos_offs = pos + offset + else: + t_cos_offs = pos + else: + t_cos_offs = t_offs + + if REUSE_FREQS_FRONT_PART: + if IS_NEOX: + d_cos_offs = d_offs + d_cos_offs = tl.where( + (d_cos_offs < BLOCK_D_HALF), + d_cos_offs, + d_cos_offs - BLOCK_D_HALF, + ).to(d_cos_offs.dtype) + d_cos_mask = d_cos_offs < BLOCK_D_HALF + else: + d_cos_offs = tl.arange(0, BLOCK_D) // 2 + d_cos_mask = d_cos_offs < BLOCK_D_HALF + else: + d_cos_offs = d_offs + d_cos_mask = d_cos_offs < BLOCK_D + + cos_mask = t_mask[:, None] & d_cos_mask[None, :] + cos_offs = t_cos_offs[:, None] * stride_cos_t + d_cos_offs[None, :] * stride_cos_d + cos = tl.load(cos_ptr + cos_offs, mask=cos_mask) + sin = tl.load(sin_ptr + cos_offs, mask=cos_mask) + + nope_offs = 0 + if HAVE_NOPE and NOPE_FIRST: + 
nope_offs = BLOCK_D + + offs_nope_ratio = 1 + if HAVE_NOPE: + offs_nope_ratio = 2 + + x_mask = t_mask[:, None] & (d_offs < BLOCK_D)[None, :] + + if IS_NEOX: + qk_rotated_mask = (d_offs < BLOCK_D_HALF)[None, :] + else: + qk_rotated_mask = (d_offs % 2 == 0)[None, :] + + H_OFFS_SIZE = hq * BLOCK_D + d_offs += nope_offs + q_in_offs = ( + t_offs[:, None] * stride_qkv_t + + (H_OFFS_SIZE * offs_nope_ratio + d_offs)[None, :] * stride_qkv_d + ) + q = tl.load(qkv_ptr + q_in_offs, mask=x_mask) + + if IS_NEOX: + q_rotated = _get_neox_rotated_x( + q, qk_rotated_mask, BLOCK_T, BLOCK_D, BLOCK_D_HALF + ) + else: + q_rotated = _get_gptj_rotated_x( + q, qk_rotated_mask, BLOCK_T, BLOCK_D, BLOCK_D_HALF + ) + + q_out_offs = ( + t_offs[:, None] * stride_q_t + d_offs[None, :] * stride_q_d + hq * stride_q_h + ) + q = q * cos + q_rotated * sin + q = q.to(q_ptr.dtype.element_ty) + tl.store(q_ptr + q_out_offs, q, mask=x_mask) + + if HAVE_NOPE: + if NOPE_FIRST: + q = tl.load(qkv_ptr + q_in_offs - BLOCK_D * stride_qkv_d, mask=x_mask) + tl.store(q_ptr + q_out_offs - BLOCK_D * stride_q_d, q, mask=x_mask) + else: + q = tl.load(qkv_ptr + q_in_offs + BLOCK_D * stride_qkv_d, mask=x_mask) + tl.store(q_ptr + q_out_offs + BLOCK_D * stride_q_d, q, mask=x_mask) + + if hq < KVH: + Q_SIZE = QH * BLOCK_D + KV_SIZE = KVH * BLOCK_D + k_in_offs = ( + t_offs[:, None] * stride_qkv_t + + ((Q_SIZE + H_OFFS_SIZE) * offs_nope_ratio + d_offs)[None, :] + * stride_qkv_d + ) + v_in_offs = ( + t_offs[:, None] * stride_qkv_t + + ((Q_SIZE + KV_SIZE + H_OFFS_SIZE) * offs_nope_ratio + d_offs)[None, :] + * stride_qkv_d + ) + k = tl.load(qkv_ptr + k_in_offs, mask=x_mask) + v = tl.load(qkv_ptr + v_in_offs, mask=x_mask) + + if IS_NEOX: + k_rotated = _get_neox_rotated_x( + k, qk_rotated_mask, BLOCK_T, BLOCK_D, BLOCK_D_HALF + ) + else: + k_rotated = _get_gptj_rotated_x( + k, qk_rotated_mask, BLOCK_T, BLOCK_D, BLOCK_D_HALF + ) + + kv_out_offs = ( + t_offs[:, None] * stride_kv_t + + d_offs[None, :] * stride_kv_d + + hq * stride_kv_h 
+ ) + k = k * cos + k_rotated * sin + k = k.to(k_ptr.dtype.element_ty) + tl.store(k_ptr + kv_out_offs, k, mask=x_mask) + v = v.to(v_ptr.dtype.element_ty) + tl.store(v_ptr + kv_out_offs, v, mask=x_mask) + + if HAVE_NOPE: + if NOPE_FIRST: + k = tl.load(qkv_ptr + k_in_offs - BLOCK_D * stride_qkv_d, mask=x_mask) + tl.store(k_ptr + kv_out_offs - BLOCK_D * stride_kv_d, k, mask=x_mask) + v = tl.load(qkv_ptr + v_in_offs - BLOCK_D * stride_qkv_d, mask=x_mask) + tl.store(v_ptr + kv_out_offs - BLOCK_D * stride_kv_d, v, mask=x_mask) + else: + k = tl.load(qkv_ptr + k_in_offs + BLOCK_D * stride_qkv_d, mask=x_mask) + tl.store(k_ptr + kv_out_offs + BLOCK_D * stride_kv_d, k, mask=x_mask) + v = tl.load(qkv_ptr + v_in_offs + BLOCK_D * stride_qkv_d, mask=x_mask) + tl.store(v_ptr + kv_out_offs + BLOCK_D * stride_kv_d, v, mask=x_mask) diff --git a/aiter/ops/triton/_triton_kernels/gemm_a16w16.py b/aiter/ops/triton/_triton_kernels/gemm_a16w16.py index 49f2e09e06..33281106eb 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a16w16.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a16w16.py @@ -20,6 +20,7 @@ def _gemm_a16_w16_kernel( a_ptr, b_ptr, + bias_ptr, c_ptr, M, N, @@ -43,6 +44,8 @@ def _gemm_a16_w16_kernel( cache_modifier: tl.constexpr, activation: tl.constexpr, use_activation: tl.constexpr, + ADD_BIAS: tl.constexpr, + SKIP_REDUCE: tl.constexpr, ): """Kernel for computing the matmul C = A x B. 
A has shape (M, K), B has shape (K, N) and C has shape (M, N) @@ -92,7 +95,16 @@ def _gemm_a16_w16_kernel( ) acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32 - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + if ADD_BIAS: + if NUM_KSPLIT == 1 or (SKIP_REDUCE and pid_k == 0): + accumulator = tl.load(bias_ptr + offs_bn).to(dtype=acc_dtype) + accumulator = tl.broadcast_to( + accumulator[None, :], (BLOCK_SIZE_M, BLOCK_SIZE_N) + ) + else: + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + else: + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) split_k_end = tl.minimum(split_k_start + SPLITK_BLOCK_SIZE, K) k_span = split_k_end - split_k_start @@ -138,6 +150,7 @@ def _gemm_a16_w16_kernel( @triton.jit def _gemm_a16w16_reduce_kernel( + bias_ptr, c_in_ptr, c_out_ptr, M, @@ -153,6 +166,7 @@ def _gemm_a16w16_reduce_kernel( MAX_KSPLIT: tl.constexpr, activation: tl.constexpr, use_activation: tl.constexpr, + ADD_BIAS: tl.constexpr, ): tl.assume(stride_c_in_k > 0) @@ -182,6 +196,11 @@ def _gemm_a16w16_reduce_kernel( else: c = tl.load(c_in_ptrs, mask=offs_k[:, None, None] < ACTUAL_KSPLIT, other=0.0) c = tl.sum(c, axis=0) + acc_dtype = tl.float32 if c_in_ptr.type.element_ty != tl.int8 else tl.int32 + if ADD_BIAS: + bias = tl.load(bias_ptr + offs_n).to(dtype=acc_dtype) + bias = tl.broadcast_to(bias[None, :], (BLOCK_SIZE_M, BLOCK_SIZE_N)) + c += bias if use_activation: c = activation(c) @@ -221,13 +240,13 @@ def _get_config( else: key = "default" # fall back to default config - bounds = [64, 128, 256, 512, 2048] + bounds = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] for bound in bounds: if M <= bound and f"M_LEQ_{bound}" in _get_config._config_dict[key]: temp_config = _get_config._config_dict[key][f"M_LEQ_{bound}"] break else: - temp_config = _get_config._config_dict[key]["M_GEQ_4096"] + temp_config = _get_config._config_dict[key]["any"] # Copy to avoid mutating the cached config 
chosen_config = dict(temp_config) diff --git a/aiter/ops/triton/_triton_kernels/gemm_a8w8_blockscale.py b/aiter/ops/triton/_triton_kernels/gemm_a8w8_blockscale.py index 224170edf4..9343d40787 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_a8w8_blockscale.py +++ b/aiter/ops/triton/_triton_kernels/gemm_a8w8_blockscale.py @@ -276,4 +276,23 @@ def _get_config( else: key = "default" # fall back to default config + if M < 32 and "small" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["small"] + elif M <= 128: + BLK_M = triton.next_power_of_2(M) + if BLK_M == 32 and "medium_M32" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["medium_M32"] + elif BLK_M == 64 and "medium_M64" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["medium_M64"] + elif BLK_M == 128 and "medium_M128" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["medium_M128"] + elif M <= 256 and "large" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["large"] + else: + BLK_M = triton.next_power_of_2(M) + if f"xlarge_M{BLK_M}" in _get_config._config_dict[key]: + return _get_config._config_dict[key][f"xlarge_M{BLK_M}"] + elif "xlarge" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["xlarge"] + return _get_config._config_dict[key]["any"] diff --git a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py index e092b3fd0b..dc94959376 100644 --- a/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm_afp4wfp4.py @@ -16,8 +16,6 @@ "EVEN_K": lambda args: (args["K"] % (args["BLOCK_SIZE_K"] // 2) == 0) and (args["SPLITK_BLOCK_SIZE"] % args["BLOCK_SIZE_K"] == 0) and (args["K"] % (args["SPLITK_BLOCK_SIZE"] // 2) == 0), - "GRID_MN": lambda args: triton.cdiv(args["M"], args["BLOCK_SIZE_M"]) - * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), } ) @triton.jit @@ -49,7 +47,6 @@ def 
_gemm_afp4_wfp4_kernel( NUM_KSPLIT: tl.constexpr, SPLITK_BLOCK_SIZE: tl.constexpr, EVEN_K: tl.constexpr, - GRID_MN: tl.constexpr, cache_modifier: tl.constexpr, ): """Kernel for computing the matmul C = A x B. @@ -69,6 +66,8 @@ def _gemm_afp4_wfp4_kernel( tl.assume(stride_bsk > 0) tl.assume(stride_bsn > 0) + GRID_MN = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) + # ----------------------------------------------------------- # Map program ids `pid` to the block of C it should compute. # This is done in a grouped ordering to promote L2 data reuse. @@ -124,7 +123,7 @@ def _gemm_afp4_wfp4_kernel( for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): a_scales = tl.load(a_scale_ptrs) - b_scales = tl.load(b_scale_ptrs) + b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier) # Load the next block of A and B, generate a mask by checking the K dimension. # If it is out of bounds, set it to 0. @@ -170,8 +169,6 @@ def _gemm_afp4_wfp4_kernel( "EVEN_K": lambda args: (args["K"] % (args["BLOCK_SIZE_K"] // 2) == 0) and (args["SPLITK_BLOCK_SIZE"] % args["BLOCK_SIZE_K"] == 0) and (args["K"] % (args["SPLITK_BLOCK_SIZE"] // 2) == 0), - "GRID_MN": lambda args: triton.cdiv(args["M"], args["BLOCK_SIZE_M"]) - * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), } ) @triton.jit @@ -203,7 +200,6 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales( NUM_KSPLIT: tl.constexpr, SPLITK_BLOCK_SIZE: tl.constexpr, EVEN_K: tl.constexpr, - GRID_MN: tl.constexpr, cache_modifier: tl.constexpr, ): """Kernel for computing the matmul C = A x B. @@ -223,6 +219,8 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales( tl.assume(stride_bsk > 0) tl.assume(stride_bsn > 0) + GRID_MN = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) + # ----------------------------------------------------------- # Map program ids `pid` to the block of C it should compute. # This is done in a grouped ordering to promote L2 data reuse. 
@@ -297,14 +295,36 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales( accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): - a_scales = tl.load(a_scale_ptrs) - b_scales = tl.load(b_scale_ptrs, cache_modifier=cache_modifier) - if BLOCK_SIZE_M >= 32: - a_scales = tl.reshape( - a_scales, (BLOCK_SIZE_M, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + if BLOCK_SIZE_M < 32: + a_scales = tl.load(a_scale_ptrs) + else: + a_scales = ( + tl.load(a_scale_ptrs) + .reshape( + BLOCK_SIZE_M // 32, + BLOCK_SIZE_K // SCALE_GROUP_SIZE // 8, + 4, + 16, + 2, + 2, + 1, + ) + .permute(0, 5, 3, 1, 4, 2, 6) + .reshape(BLOCK_SIZE_M, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + b_scales = ( + tl.load(b_scale_ptrs, cache_modifier=cache_modifier) + .reshape( + BLOCK_SIZE_N // 32, + BLOCK_SIZE_K // SCALE_GROUP_SIZE // 8, + 4, + 16, + 2, + 2, + 1, ) - b_scales = tl.reshape( - b_scales, (BLOCK_SIZE_N, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + .permute(0, 5, 3, 1, 4, 2, 6) + .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K // SCALE_GROUP_SIZE) ) # Load the next block of A and B, generate a mask by checking the K dimension. 
@@ -346,6 +366,221 @@ def _gemm_afp4_wfp4_kernel_preshuffled_scales( tl.store(c_ptrs, c, mask=c_mask, cache_modifier=".wt") +@triton.heuristics( + { + "EVEN_K": lambda args: (args["K"] % (args["BLOCK_SIZE_K"] // 2) == 0) + and (args["SPLITK_BLOCK_SIZE"] % args["BLOCK_SIZE_K"] == 0) + and (args["K"] % (args["SPLITK_BLOCK_SIZE"] // 2) == 0), + } +) +@triton.jit +def _gemm_afp4_wfp4_kernel_preshuffled_weight_scales( + a_ptr, + b_ptr, + c_ptr, + a_scales_ptr, + b_scales_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bn, + stride_bk, + stride_ck, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bsn, + stride_bsk, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_KSPLIT: tl.constexpr, + SPLITK_BLOCK_SIZE: tl.constexpr, + EVEN_K: tl.constexpr, + cache_modifier: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A and B inputs are in the microscale fp4 (mxfp4) format. + A_scales and B_scales are in e8m0 format. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + tl.assume(stride_asm > 0) + tl.assume(stride_ask > 0) + tl.assume(stride_bsk > 0) + tl.assume(stride_bsn > 0) + + GRID_MN = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) + + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. 
+ pid_unified = tl.program_id(axis=0) + pid_unified = remap_xcd(pid_unified, GRID_MN * NUM_KSPLIT, NUM_XCDS=8) + pid_k = pid_unified % NUM_KSPLIT + pid = pid_unified // NUM_KSPLIT + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + if NUM_KSPLIT == 1: + pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M) + else: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + # We assume 32 elements along K share the same scale. + SCALE_GROUP_SIZE: tl.constexpr = 32 + + if (pid_k * SPLITK_BLOCK_SIZE // 2) < K: + + num_k_iter = tl.cdiv(SPLITK_BLOCK_SIZE // 2, BLOCK_SIZE_K // 2) + + # Create pointers for first block of A and B input matrices + # The BLOCK sizes are of the elements and in fp4 we pack 2 per uint8 container. + offs_k = tl.arange(0, BLOCK_SIZE_K // 2) + offs_k_shuffle_arr = tl.arange(0, (BLOCK_SIZE_K // 2) * 16) + offs_k_split = pid_k * (SPLITK_BLOCK_SIZE // 2) + offs_k + offs_k_shuffle = pid_k * (SPLITK_BLOCK_SIZE // 2) * 16 + offs_k_shuffle_arr + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * (BLOCK_SIZE_N // 16) + tl.arange(0, BLOCK_SIZE_N // 16)) % N + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak + ) + b_ptrs = b_ptr + ( + # offs_k_split[:, None] * stride_bk + offs_bn[None, :] * stride_bn + offs_bn[:, None] * stride_bn + + offs_k_shuffle[None, :] * stride_bk + ) + # Create pointers for the first block of A and B scales + + offs_asn = ( + pid_n * (BLOCK_SIZE_N // 32) + tl.arange(0, (BLOCK_SIZE_N // 32)) + ) % N + offs_ks = (pid_k * (SPLITK_BLOCK_SIZE // SCALE_GROUP_SIZE) * 32) + tl.arange( + 0, BLOCK_SIZE_K // SCALE_GROUP_SIZE * 32 + ) + # B scales are N x K even though B operand is K x N. 
+ b_scale_ptrs = ( + b_scales_ptr + + offs_asn[:, None] * stride_bsn + + offs_ks[None, :] * stride_bsk + ) + + if BLOCK_SIZE_M < 32: + offs_ks_non_shufl = ( + pid_k * (SPLITK_BLOCK_SIZE // SCALE_GROUP_SIZE) + ) + tl.arange(0, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + a_scale_ptrs = ( + a_scales_ptr + + offs_am[:, None] * stride_asm + + offs_ks_non_shufl[None, :] * stride_ask + ) + else: + offs_asm = ( + pid_m * (BLOCK_SIZE_M // 32) + tl.arange(0, (BLOCK_SIZE_M // 32)) + ) % M + a_scale_ptrs = ( + a_scales_ptr + + offs_asm[:, None] * stride_asm + + offs_ks[None, :] * stride_ask + ) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter): + if BLOCK_SIZE_M < 32: + a_scales = tl.load(a_scale_ptrs) + else: + a_scales = ( + tl.load(a_scale_ptrs) + .reshape( + BLOCK_SIZE_M // 32, + BLOCK_SIZE_K // SCALE_GROUP_SIZE // 8, + 4, + 16, + 2, + 2, + 1, + ) + .permute(0, 5, 3, 1, 4, 2, 6) + .reshape(BLOCK_SIZE_M, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + + b_scales = ( + tl.load(b_scale_ptrs, cache_modifier=cache_modifier) + .reshape( + BLOCK_SIZE_N // 32, + BLOCK_SIZE_K // SCALE_GROUP_SIZE // 8, + 4, + 16, + 2, + 2, + 1, + ) + .permute(0, 5, 3, 1, 4, 2, 6) + .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K // SCALE_GROUP_SIZE) + ) + + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs, cache_modifier=cache_modifier) + + b = ( + b.reshape( + 1, + BLOCK_SIZE_N // 16, + BLOCK_SIZE_K // 64, + 2, + 16, + 16, + ) + .permute(0, 1, 4, 2, 3, 5) + .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K // 2) + .trans(1, 0) + ) + + accumulator += tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1") + + # Advance the ptrs to the next K block. 
+ a_ptrs += (BLOCK_SIZE_K // 2) * stride_ak + b_ptrs += (BLOCK_SIZE_K // 2) * 16 * stride_bk + if BLOCK_SIZE_M < 32: + a_scale_ptrs += (BLOCK_SIZE_K // SCALE_GROUP_SIZE) * stride_ask + else: + a_scale_ptrs += BLOCK_SIZE_K * stride_ask + b_scale_ptrs += BLOCK_SIZE_K * stride_bsk + + c = accumulator.to(c_ptr.type.element_ty) + + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + c_ptrs = ( + c_ptr + + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :] + + pid_k * stride_ck + ) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask, cache_modifier=".wt") + + @triton.jit def _gemm_afp4_wfp4_reduce_kernel( c_in_ptr, @@ -398,29 +633,34 @@ def _get_config( M: int, N: int, K: int, + shuffle: bool = False, ): - if not hasattr(_get_config, "_config_dict"): + shuffle_filename_suffix = "" if not shuffle else "_PRESHUFFLED" + if not hasattr(_get_config, "_config_dict") or not hasattr( + _get_config._config_dict, f"default{shuffle_filename_suffix}" + ): dev = arch_info.get_device() _get_config._config_dict = {} - fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-AFP4WFP4.json" + fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-AFP4WFP4{shuffle_filename_suffix}.json" with open(fpath, "r") as file: config = json.load(file) - _get_config._config_dict["default"] = config + _get_config._config_dict[f"default{shuffle_filename_suffix}"] = config - key = f"{N}_{K}" + key = f"{N}_{K}{shuffle_filename_suffix}" if key not in _get_config._config_dict.keys(): dev = arch_info.get_device() - fpath = ( - f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-AFP4WFP4-N={N}-K={2*K}.json" - ) + fpath = f"{AITER_TRITON_CONFIGS_PATH}/gemm/{dev}-GEMM-AFP4WFP4{shuffle_filename_suffix}-N={N}-K={2*K}.json" if os.path.exists(fpath): with open(fpath, "r") as file: config = json.load(file) 
_get_config._config_dict[key] = config else: - key = "default" # fall back to default config + key = f"default{shuffle_filename_suffix}" # fall back to default config if M < 32: + BLK_M = triton.next_power_of_2(M) + if BLK_M >= 16 and "small_M16" in _get_config._config_dict[key]: + return _get_config._config_dict[key]["small_M16"] return _get_config._config_dict[key]["small"] elif M <= 128: BLK_M = triton.next_power_of_2(M) diff --git a/aiter/ops/triton/_triton_kernels/split_qkv.py b/aiter/ops/triton/_triton_kernels/split_qkv.py new file mode 100644 index 0000000000..042b78d900 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/split_qkv.py @@ -0,0 +1,38 @@ +import triton +import triton.language as tl + + +@triton.jit +def _split_qkv_kernel( + qkv_ptr, + q_ptr, + k_ptr, + v_ptr, + qkv_stride_0, + q_stride_0, + k_stride_0, + v_stride_0, + Q_SIZE: tl.constexpr, + KV_SIZE: tl.constexpr, +): + + wid = tl.program_id(0) + + q_offs = tl.arange(0, Q_SIZE) + kv_offs = tl.arange(0, KV_SIZE) + + q_load_ptrs = qkv_ptr + (wid * qkv_stride_0) + q_offs + k_load_ptrs = qkv_ptr + (wid * qkv_stride_0) + Q_SIZE + kv_offs + v_load_ptrs = qkv_ptr + (wid * qkv_stride_0) + Q_SIZE + KV_SIZE + kv_offs + + q = tl.load(q_load_ptrs) + k = tl.load(k_load_ptrs) + v = tl.load(v_load_ptrs) + + q_store_ptrs = q_ptr + (wid * q_stride_0) + q_offs + k_store_ptrs = k_ptr + (wid * k_stride_0) + kv_offs + v_store_ptrs = v_ptr + (wid * v_stride_0) + kv_offs + + tl.store(q_store_ptrs, q) + tl.store(k_store_ptrs, k) + tl.store(v_store_ptrs, v) diff --git a/aiter/ops/triton/_triton_kernels/unified_attention.py b/aiter/ops/triton/_triton_kernels/unified_attention.py new file mode 100644 index 0000000000..811d3405ef --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/unified_attention.py @@ -0,0 +1,784 @@ +# The kernels in this file are adapted from vLLM: +# https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_unified_attention.py +import triton +import triton.language as tl +import 
torch +from aiter.ops.triton.utils.types import e4m3_dtype +import math + +float8_info = torch.finfo(e4m3_dtype) + + +@triton.jit +def fast_exp(x): + RCP_LN2: tl.constexpr = 1.4426950408889634 + return tl.math.exp2(x * RCP_LN2) + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def apply_softcap(S, x): + Sdiv = S / x + p1 = tl.math.exp2(Sdiv) + p2 = tl.math.exp2(-Sdiv) + return x * (p1 - p2) / (p1 + p2) + + +@triton.jit +def find_seq_idx( + query_start_len_ptr, + target_idx, + num_seqs, + BLOCK_Q: tl.constexpr, + use_q_block_mode: tl.constexpr, +): + left: tl.int32 = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + val = tl.load(query_start_len_ptr + mid) + mid_val = val // BLOCK_Q + mid if use_q_block_mode else val + + if mid_val <= target_idx: + left = mid + 1 + else: + right = mid + + return left - 1 + + +@triton.jit +def kernel_unified_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] + scale: tl.constexpr, # float32 + k_scale, # float32 + v_scale, # float32 + out_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + qq_bias_stride_0: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int must be power of 2 + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: 
tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_QQ_BIAS: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + USE_SINKS: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + USE_FP8: tl.constexpr, # bool + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, + ALL_DECODE: tl.constexpr = False, # bool +): + kv_head_idx = tl.program_id(0) + q_block_global_idx = tl.program_id(1) + + # needed to use exp2 (exp2 -> exp conversion) + RCP_LN2 = 1.4426950408889634 + qk_scale = scale * RCP_LN2 + + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) + + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + offs_t = tl.arange(0, TILE_SIZE) + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) 
+ + if HEAD_SIZE_PADDED != HEAD_SIZE: + dim_mask = offs_d < HEAD_SIZE + else: + dim_mask = tl.full((1,), 1, dtype=tl.int1) + query_mask_0 = query_pos < cur_batch_query_len + query_mask_1 = query_offset_1 < num_query_heads + + if ALL_DECODE or BLOCK_M >= num_query_heads: + Q_cache_modifier: tl.constexpr = ".cg" + else: + Q_cache_modifier: tl.constexpr = "" + # Q : (BLOCK_M, HEAD_SIZE_PADDED) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + cache_modifier=Q_cache_modifier, + ) + + block_table_offset = seq_idx * block_table_stride + + if not USE_SINKS: + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + else: + # Prescale with RCP_LN2, needed for exp2 + M = ( + tl.load( + sink_ptr + query_offset_1, + mask=query_mask_1, + other=float("-inf"), + ).to(dtype=tl.float32) + * RCP_LN2 + ) + + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load( + alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0 + ) + + # query-query attention bias + if USE_QQ_BIAS: + qq_bias_row_ptrs = ( + qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 + ) # shape: [BLOCK_M] + + # compute the length of the longest sequence prefix spanned by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = ( + context_len + + q_block_local_idx * BLOCK_Q + + (BLOCK_M - 1) // num_queries_per_kv + + 1 + ) + + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + + # calculate the number of tiles that need to be processed to + # cover the longest sequence 
prefix (due to causal masking, tiles beyond + # this prefix can be skipped) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) + + # ---- Sliding-window tile pruning -------------------- + # Default: keep previous global behavior + tile_start = 0 + tile_end = num_tiles + if SLIDING_WINDOW > 0: + # Query rows covered by this Q-block + qpos_lo = q_block_local_idx * BLOCK_Q + qpos_hi = tl.minimum( + qpos_lo + (BLOCK_M - 1) // num_queries_per_kv, + cur_batch_query_len - 1, + ) + # For sliding window, each query position q can only attend to + # keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs] + # where q_abs = context_len + q + # The union of allowed key positions for this Q-block is: + # [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi] + first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1 + last_allowed_key = context_len + qpos_hi + # Convert to tile indices and clamp + tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE) + tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles) + + KV_cache_modifier: tl.constexpr = ".cg" if ALL_DECODE else "" + # iterate through tiles (now limited to the sliding window range) + for j in range(tile_start, tile_end): + seq_offset = j * TILE_SIZE + offs_t + # to reduce the masking effect when not needed + if TILE_SIZE == BLOCK_SIZE: + tile_mask = tl.full((1,), 1, dtype=tl.int1) + else: + tile_mask = seq_offset < max_seq_prefix_len + + physical_block_idx = tl.load( + block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE + ).to(tl.int64) + + v_offset = ( + physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 + ) + + k_offset = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) + + # K : (HEAD_SIZE, TILE_SIZE) + 
K_load = tl.load( + key_cache_ptr + k_offset, + mask=dim_mask[:, None] & tile_mask[None, :], + other=0.0, + cache_modifier=KV_cache_modifier, + ) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (TILE_SIZE, HEAD_SIZE) + V_load = tl.load( + value_cache_ptr + v_offset, + mask=dim_mask[None, :] & tile_mask[:, None], + other=0.0, + cache_modifier=KV_cache_modifier, + ) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + # S : (BLOCK_M, TILE_SIZE) + # qk_scale = scale * RCP_LN2 (log_2 e) so that we can use exp2 later + S = qk_scale * tl.dot(Q, K) + + if USE_SOFTCAP: + # softcap here uses exp2 and consumes RCP_LN2 conversion. + # multiply by RCP_LN2 again to be used in later exp2 + S = apply_softcap(S, softcap) * RCP_LN2 + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + S = tl.where( + query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf") + ) + + if SLIDING_WINDOW > 0: + S = tl.where( + (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, + S, + float("-inf"), + ) + + if USE_ALIBI_SLOPES: + # prescale w. RCP_LN2 for later exp2 + S += alibi_slope[:, None] * (seq_offset - context_len) * RCP_LN2 + + if USE_QQ_BIAS: + # compute key positions relative to query section + key_rel_pos = seq_offset - context_len # shape: [BLOCK_SIZE] + # load bias only for keys that correspond to queries + is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0 + qq_bias = tl.load( + qq_bias_row_ptrs + key_rel_pos[None, :], + mask=is_query_key[None, :], # avoid OOB for context keys + other=0.0, + ) + # prescale w. 
RCP_LN2 for later exp2 + S += qq_bias * RCP_LN2 + + # compute running maximum + # m_j : (BLOCK_M,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_M, TILE_SIZE) + P = tl.math.exp2(S - m_j[:, None]) + + # l_j : (BLOCK_M,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_M, ) + alpha = tl.math.exp2(M - m_j) + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc += tl.dot(P.to(V.dtype), V) + + # epilogue + # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. + one_over_L = 1.0 / L[:, None] + acc = acc * one_over_L + if USE_FP8: + acc = acc * tl.load(out_scale) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) + + output_offset = ( + query_offset_0[:, None] * output_stride_0 + + query_offset_1[:, None] * output_stride_1 + + offs_d[None, :] + ) + + tl.store( + output_ptr + output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + + +@triton.jit +def kernel_unified_attention_3d( + segm_output_ptr, + # [num_tokens, num_query_heads, num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + 
num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + qq_bias_stride_0: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_QQ_BIAS: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + USE_SINKS: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + ALL_DECODE: tl.constexpr = False, # bool +): + q_block_global_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + segm_idx = tl.program_id(2) + + # needed to use exp2 (exp2 -> exp conversion) + RCP_LN2 = 1.4426950408889634 + qk_scale = scale * RCP_LN2 + + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) + + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + 
num_segments = NUM_SEGMENTS_PER_SEQ + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) + + if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + offs_t = tl.arange(0, TILE_SIZE) + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) + + if HEAD_SIZE_PADDED != HEAD_SIZE: + dim_mask = offs_d < HEAD_SIZE + else: + dim_mask = tl.full((1,), 1, dtype=tl.int1) + query_mask_0 = query_pos < cur_batch_query_len + query_mask_1 = query_offset_1 < num_query_heads + + # Q : (BLOCK_M, HEAD_SIZE_PADDED) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + if USE_SINKS: + if segm_idx == 0: + # Prescale with RCP_LN2, needed for exp2 + M = ( + tl.load( + sink_ptr + query_offset_1, + mask=query_mask_1, + other=float("-inf"), + ).to(dtype=tl.float32) + * RCP_LN2 + ) + else: + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + else: + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load( + alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0 + ) + + # query-query attention bias + if USE_QQ_BIAS: + qq_bias_row_ptrs = ( + qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 + ) # shape: [BLOCK_M] + + # compute the length of the longest sequence prefix spanned 
by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = ( + context_len + + q_block_local_idx * BLOCK_Q + + (BLOCK_M - 1) // num_queries_per_kv + + 1 + ) + + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + + # calculate the number of tiles that need to be processed to + # cover the longest sequence prefix (due to causal masking, tiles beyond + # this prefix can be skipped) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) + + KV_cache_modifier: tl.constexpr = ".cg" if ALL_DECODE else "" + # iterate through tiles within current segment + for j in range( + segm_idx * tiles_per_segment, + min((segm_idx + 1) * tiles_per_segment, num_tiles), + ): + seq_offset = j * TILE_SIZE + offs_t + if TILE_SIZE == BLOCK_SIZE: + tile_mask = tl.full((1,), 1, dtype=tl.int1) + else: + tile_mask = seq_offset < max_seq_prefix_len + + physical_block_idx = tl.load( + block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE + ).to(tl.int64) + + v_offset = ( + physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 + ) + + k_offset = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) + + # K : (HEAD_SIZE, TILE_SIZE) + K_load = tl.load( + key_cache_ptr + k_offset, + mask=dim_mask[:, None] & tile_mask[None, :], + other=0.0, + cache_modifier=KV_cache_modifier, + ) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (TILE_SIZE, HEAD_SIZE) + V_load = tl.load( + value_cache_ptr + v_offset, + mask=dim_mask[None, :] & tile_mask[:, None], + other=0.0, + 
cache_modifier=KV_cache_modifier, + ) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_M, TILE_SIZE) + # qk_scale = scale * RCP_LN2 (log_2 e) so that we can use exp2 later + S = qk_scale * tl.dot(Q, K) + + if USE_SOFTCAP: + # softcap here uses exp2 and consumes RCP_LN2 conversion. + # multiply by RCP_LN2 again to be used in later exp2 + S = apply_softcap(S, softcap) * RCP_LN2 + + S = tl.where( + query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf") + ) + + if SLIDING_WINDOW > 0: + S = tl.where( + (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, + S, + float("-inf"), + ) + + if USE_ALIBI_SLOPES: + # prescale w. RCP_LN2 for later exp2 + S += alibi_slope[:, None] * (seq_offset - context_len) * RCP_LN2 + + if USE_QQ_BIAS: + # compute key positions relative to query section + key_rel_pos = seq_offset - context_len # shape: [BLOCK_SIZE] + # load bias only for keys that correspond to queries + is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0 + qq_bias = tl.load( + qq_bias_row_ptrs + key_rel_pos[None, :], + mask=is_query_key[None, :], # avoid OOB for context keys + other=0.0, + ) + # prescale w. RCP_LN2 for later exp2 + S += qq_bias * RCP_LN2 + + # compute running maximum + # m_j : (BLOCK_M,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. 
In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_M, TILE_SIZE,) + P = tl.math.exp2(S - m_j[:, None]) + + # l_j : (BLOCK_M,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_M, ) + alpha = tl.math.exp2(M - m_j) + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc += tl.dot(P.to(V.dtype), V) + + segm_output_offset = ( + query_offset_0[:, None].to(tl.int64) + * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + segm_idx * HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :] + ) + tl.store( + segm_output_ptr + segm_output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + segm_offset = ( + query_offset_0.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_offset_1 * NUM_SEGMENTS_PER_SEQ + + segm_idx + ) + tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1) + tl.store(segm_expsum_ptr + segm_offset, L, mask=query_mask_0 & query_mask_1) + + +@triton.jit +def reduce_segments( + output_ptr, # [num_tokens, num_query_heads, head_size] + segm_output_ptr, + # [num_tokens, num_query_heads, max_num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] + seq_lens_ptr, # [num_seqs] + num_seqs, # int + num_query_heads: tl.constexpr, # int + out_scale_inv, # float32 + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + block_table_stride: tl.int64, # int + TILE_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: 
tl.constexpr,  # int
+    USE_FP8: tl.constexpr,  # bool
+    FP8_MIN: tl.constexpr = float8_info.min,
+    FP8_MAX: tl.constexpr = float8_info.max,
+):
+    query_token_idx = tl.program_id(0)
+    query_head_idx = tl.program_id(1)
+
+    seq_idx = find_seq_idx(
+        query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False
+    )
+
+    # sequence len for this particular sequence
+    seq_len = tl.load(seq_lens_ptr + seq_idx)
+
+    # number of segments for this particular sequence
+    num_segments = NUM_SEGMENTS_PER_SEQ
+    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
+
+    # create masks for subsequent loads
+    act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE)
+    segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full(
+        [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32
+    )
+
+    if HEAD_SIZE_PADDED != HEAD_SIZE:
+        dim_mask = tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE
+    else:
+        dim_mask = tl.full((1,), 1, dtype=tl.int1)
+
+    # load segment maxima
+    segm_offset = (
+        query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
+        + query_head_idx * NUM_SEGMENTS_PER_SEQ
+        + tl.arange(0, NUM_SEGMENTS_PER_SEQ)
+    )
+    segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf"))
+    overall_max = tl.max(segm_max)
+
+    # load and rescale segment exp sums
+    segm_expsum = tl.load(segm_expsum_ptr + segm_offset, mask=segm_mask, other=0.0)
+    segm_expsum = segm_expsum * tl.math.exp2(segm_max - overall_max)
+    overall_expsum = tl.sum(segm_expsum)
+
+    # load, rescale, and add segment attention outputs
+    segm_output_offset = (
+        query_token_idx.to(tl.int64)
+        * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
+        + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED)
+        + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED
+        + tl.arange(0, HEAD_SIZE_PADDED)[None, :]
+    )
+    segm_output = tl.load(
+        segm_output_ptr + segm_output_offset,
+        mask=segm_mask[:, None] & dim_mask[None, :],
+        other=0.0,
+    )
+    segm_output *= tl.math.exp2(segm_max -
overall_max)[:, None]
+    acc_sum = tl.sum(segm_output, axis=0)
+    # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
+    acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)
+
+    if USE_FP8:
+        acc = acc * tl.load(out_scale_inv)
+        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
+
+    # write result
+    output_offset = (
+        query_token_idx * output_stride_0
+        + query_head_idx * output_stride_1
+        + tl.arange(0, HEAD_SIZE_PADDED)
+    )
+    tl.store(output_ptr + output_offset, acc, mask=dim_mask)
diff --git a/aiter/ops/triton/activation.py b/aiter/ops/triton/activation.py
index adb34b97eb..7bb7bbee15 100644
--- a/aiter/ops/triton/activation.py
+++ b/aiter/ops/triton/activation.py
@@ -2,9 +2,13 @@
 import triton
 import triton.language as tl
 import torch
+import aiter
+
+fp8_dtype = aiter.dtypes.fp8
 from aiter.ops.triton.utils.logger import AiterTritonLogger
 from aiter.ops.triton._triton_kernels.activation import (
     _act_mul_and_dynamic_mxfp4_quant_kernel,
+    _act_mul_and_dynamic_fp8_group_quant_kernel,
 )
 
 _LOGGER = AiterTritonLogger()
@@ -125,3 +129,75 @@ def act_mul_and_mxfp4_quant(
     )
 
     return x_fp4, blockscale_e8m0
+
+
+def act_mul_and_fp8_group_quant(
+    x: torch.Tensor,
+    activation: Literal["silu", "gelu", "gelu_tanh"],
+    group_size,
+    dtype_quant=fp8_dtype,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Apply the activation function and quantize the result to FP8 with per-group scales.
+
+    Args:
+        x: The input tensor, typically fp16 or bf16, of shape (M, N) with N even.
+        activation: activation function to apply before quantization.
+            - It splits the features into two halves and applies the activation to the first half.
+            - Then, it multiplies the two results together before quantization.
+            - Supports the following activations:
+                - "silu"
+                - "gelu"
+                - "gelu_tanh"
+
+        group_size: number of consecutive elements of the activated output that
+            share one float32 scale.
+        dtype_quant: target FP8 dtype of the quantized output tensor.
+
+    Returns:
+        A tuple of (x_fp8, out_bs): the quantized (M, N // 2) tensor and its
+        float32 per-group scales of shape (M, cdiv(N // 2, group_size)).
+    """
+    _LOGGER.info(f"ACT_MUL_FP8_GROUP_QUANT: x={tuple(x.shape)} activation={activation}")
+    # Assume x is 2D-Tensor for now
+    M, N = x.shape
+    assert N % 2 == 0
+
+    N_half = N // 2
+    scaleN = triton.cdiv(N_half, group_size)
+    x_fp8 = torch.empty((M, N_half), dtype=dtype_quant, device=x.device)
+    out_bs = torch.empty(
+        (M, triton.cdiv(N_half, group_size)), dtype=torch.float32, device=x.device
+    )
+
+    DTYPE_MAX = (
+        torch.finfo(x_fp8.dtype).max
+        if torch.is_floating_point(x_fp8)
+        else torch.iinfo(x_fp8.dtype).max
+    )
+    BLOCK_SIZE_N = group_size
+
+    grid = (
+        M,
+        triton.cdiv(N_half, BLOCK_SIZE_N),
+    )
+    _act_mul_and_dynamic_fp8_group_quant_kernel[grid](
+        x,
+        x_fp8,
+        out_bs,
+        *x.stride(),
+        *x_fp8.stride(),
+        *out_bs.stride(),
+        N=N_half,
+        ACTIVATION=activation,
+        scaleN=scaleN,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        QUANT_BLOCK_SIZE=group_size,
+        DTYPE_MAX=DTYPE_MAX,
+        DTYPE_MIN=-DTYPE_MAX,
+        # num_warps=NUM_WARPS,
+        # waves_per_eu=0,
+        # num_stages=1,
+    )
+
+    return x_fp8, out_bs
diff --git a/aiter/ops/triton/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py b/aiter/ops/triton/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py
index a86f7d655a..64d2e73f99 100644
--- a/aiter/ops/triton/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py
+++ b/aiter/ops/triton/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py
@@ -23,6 +23,7 @@ def batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(
     splitK: Optional[int] = None,
     YQ: Optional[torch.Tensor] = None,
     transpose_bm: Optional[bool] = False,
+    transpose_bm_in: Optional[bool] = False,
     config: Optional[dict] = None,
 ):
     """
@@ -33,19 +34,28 @@ def batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(
     2D one before being applied.
Key parameters: - - XQ: Batch tensor XQ with shape (B, M, K). + - XQ: Batch tensor XQ with shape (B, M, K) if transpose_bm_in == False else (M, B, K). - WQ: Batch tensor WQ with shape (B, N, K). - W_scale: Second scale batch tensor with shape (1, ). - Bias: Bias batch tensor with shape (B, 1, N). - YQ: Output Matrix Y with shape (B, M, N). If this is none, then it's created by this API and returned as output Returns: - - YQ: The output batch tensor with shape (B, M, N). + - YQ: The output batch tensor with shape (B, M, N) if transpose_bm == False else (M, B, N). """ # Check constraints. - assert X.shape[0] == WQ.shape[0], "Incompatible Batch dimensions!!!" - assert X.shape[2] == WQ.shape[2], "Incompatible K dimensions!!!" + if not transpose_bm_in: + B = X.shape[0] + M = X.shape[1] + else: + M = X.shape[0] + B = X.shape[1] + K = X.shape[2] + N = WQ.shape[1] + + assert B == WQ.shape[0], "Incompatible Batch dimensions!!!" + assert K == WQ.shape[2], "Incompatible K dimensions!!!" assert ( triton.next_power_of_2(group_size) == group_size ), "group_size mush be power of 2" @@ -55,11 +65,6 @@ def batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( ], f"Output {dtype=} is currently not supported in batched_gemm_a8w8" assert splitK is None, "Currently, there isn't any support for splitK on Triton" - B = X.shape[0] - M = X.shape[1] - K = X.shape[2] - N = WQ.shape[1] - WQ = WQ.transpose(1, 2) has_bias = bias is not None @@ -104,8 +109,8 @@ def batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( M, N, K, - X.stride(0), - X.stride(1), + X.stride(0) if not transpose_bm_in else X.stride(1), + X.stride(1) if not transpose_bm_in else X.stride(0), X.stride(2), WQ.stride(0), WQ.stride(1), diff --git a/aiter/ops/triton/configs/gemm/MI300X-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16.json b/aiter/ops/triton/configs/gemm/MI300X-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16.json new file mode 100644 index 0000000000..b1a39b0746 --- /dev/null +++ 
b/aiter/ops/triton/configs/gemm/MI300X-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI300X-GEMM-A16W16.json b/aiter/ops/triton/configs/gemm/MI300X-GEMM-A16W16.json index dffd8ec71f..7d43b7efa1 100644 --- a/aiter/ops/triton/configs/gemm/MI300X-GEMM-A16W16.json +++ b/aiter/ops/triton/configs/gemm/MI300X-GEMM-A16W16.json @@ -9,7 +9,7 @@ "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": ".cg", - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 }, "M_LEQ_128": { @@ -22,7 +22,7 @@ "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": ".cg", - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 }, "M_LEQ_256": { @@ -35,7 +35,7 @@ "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": ".cg", - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 }, "M_LEQ_512": { @@ -48,7 +48,7 @@ "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": null, - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 }, "M_LEQ_2048": { @@ -61,10 +61,10 @@ "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": null, - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 }, - "M_GEQ_4096": { + "any": { "BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, @@ -74,7 +74,7 @@ "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": null, - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 } } diff --git a/aiter/ops/triton/configs/gemm/MI350X-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=128-K=512.json b/aiter/ops/triton/configs/gemm/MI350X-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=128-K=512.json new file mode 100644 index 0000000000..3823c3fc2e --- /dev/null +++ 
b/aiter/ops/triton/configs/gemm/MI350X-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=128-K=512.json @@ -0,0 +1,81 @@ +{ + "small" : { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2, + "cache_modifier": ".cg" + }, + "medium_M32" : { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2, + "cache_modifier": ".cg" + }, + "medium_M64" : { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2, + "cache_modifier": ".cg" + }, + "medium_M128" : { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2, + "cache_modifier": ".cg" + }, + "large" : { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2, + "cache_modifier": ".cg" + }, + "xlarge" : { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2, + "cache_modifier": ".cg" + }, + "any" : { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2, + "cache_modifier": ".cg" + } +} + + \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI350X-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=512-K=128.json b/aiter/ops/triton/configs/gemm/MI350X-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=512-K=128.json new file mode 100644 index 
0000000000..fdd6010f34 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT-N=512-K=128.json @@ -0,0 +1,81 @@ +{ + "small" : { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "kpack": 2, + "cache_modifier": ".cg" + }, + "medium_M32" : { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2, + "cache_modifier": ".cg" + }, + "medium_M64" : { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2, + "cache_modifier": ".cg" + }, + "medium_M128" : { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "kpack": 2, + "cache_modifier": ".cg" + }, + "large" : { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "kpack": 2, + "cache_modifier": ".cg" + }, + "xlarge" : { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "kpack": 2, + "cache_modifier": ".cg" + }, + "any" : { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "kpack": 2, + "cache_modifier": ".cg" + } +} + + \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI350X-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16-N8=512-N16=256-K=7168.json b/aiter/ops/triton/configs/gemm/MI350X-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16-N8=512-N16=256-K=7168.json new file mode 100644 index 0000000000..8658e29031 --- 
/dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16-N8=512-N16=256-K=7168.json @@ -0,0 +1,110 @@ +{ + "small": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "small_M16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "medium_M32": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "medium_M128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "xlarge_M1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "xlarge_M2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 
1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16.json b/aiter/ops/triton/configs/gemm/MI350X-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16.json new file mode 100644 index 0000000000..b1a39b0746 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=128-K=2880.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=128-K=2880.json new file mode 100644 index 0000000000..b6eca2c53d --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=128-K=2880.json @@ -0,0 +1,145 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 1, + "BLOCK_SIZE_N": 4, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 4, + "num_stages": 2 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 1, + "BLOCK_SIZE_N": 4, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 4, + "num_stages": 2 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 1, + "BLOCK_SIZE_N": 4, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 4, + "num_stages": 2 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + 
"waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + 
"GROUP_SIZE_M": 4, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=2280-K=512.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=2280-K=512.json new file mode 100644 index 0000000000..cb26dde3dd --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=2280-K=512.json @@ -0,0 +1,145 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 4, + "num_stages": 2 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 3 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 4, + "num_stages": 2 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + 
"num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 4, + "num_stages": 2 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=256-K=7168.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=256-K=7168.json index 4056d9b04e..94947a5a1a 100644 --- a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=256-K=7168.json +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=256-K=7168.json @@ -9,7 +9,7 @@ "waves_per_eu": 1, "matrix_instr_nonkdim": 16, "cache_modifier": ".cg", - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 }, "M_LEQ_128": { @@ -22,7 +22,7 @@ "waves_per_eu": 1, "matrix_instr_nonkdim": 16, "cache_modifier": ".cg", - 
"NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 }, "M_LEQ_256": { @@ -35,7 +35,7 @@ "waves_per_eu": 1, "matrix_instr_nonkdim": 32, "cache_modifier": null, - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 }, "M_LEQ_512": { @@ -48,7 +48,7 @@ "waves_per_eu": 1, "matrix_instr_nonkdim": 32, "cache_modifier": null, - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 }, "M_LEQ_2048": { @@ -61,7 +61,7 @@ "waves_per_eu": 1, "matrix_instr_nonkdim": 32, "cache_modifier": null, - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 }, "M_GEQ_4096": { @@ -74,7 +74,7 @@ "waves_per_eu": 1, "matrix_instr_nonkdim": 32, "cache_modifier": null, - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 } } diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=2880-K=4096.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=2880-K=4096.json new file mode 100644 index 0000000000..0773aa69f6 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=2880-K=4096.json @@ -0,0 +1,145 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 
1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 4, + "num_stages": 2 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + } +} \ No newline at end of file diff --git 
a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=5120-K=2880.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=5120-K=2880.json new file mode 100644 index 0000000000..d80aede426 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=5120-K=2880.json @@ -0,0 +1,145 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 4, + "num_stages": 2 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "waves_per_eu": 1, + 
"matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=640-K=2880.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=640-K=2880.json new file mode 100644 index 0000000000..df01dc2c7c --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16-N=640-K=2880.json @@ -0,0 +1,145 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + 
"M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 4, + "num_stages": 2 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "M_LEQ_4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + 
"num_stages": 2 + }, + "M_LEQ_8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16.json index 06d6583b04..34d1d730ff 100644 --- a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16.json +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A16W16.json @@ -9,7 +9,7 @@ "waves_per_eu": 3, "matrix_instr_nonkdim": 16, "cache_modifier": ".cg", - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 }, "M_LEQ_128": { @@ -22,7 +22,7 @@ "waves_per_eu": 3, "matrix_instr_nonkdim": 16, "cache_modifier": ".cg", - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 }, "M_LEQ_256": { @@ -35,7 +35,7 @@ "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": ".cg", - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 }, "M_LEQ_512": { @@ -48,7 +48,7 @@ "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": ".cg", - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 }, "M_LEQ_2048": { @@ -61,10 +61,10 @@ "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": null, - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 }, - "M_GEQ_4096": { + "any": { "BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, @@ -74,7 +74,7 @@ "waves_per_eu": 2, "matrix_instr_nonkdim": 16, "cache_modifier": null, - "NUM_KSPLIT": 1, + "NUM_KSPLIT": 1, "kpack": 1 } } diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=2112-K=7168.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=2112-K=7168.json new file mode 100644 index 0000000000..6247a3c257 
--- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=2112-K=7168.json @@ -0,0 +1,62 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "medium_M32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "medium_M128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=3072-K=1536.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=3072-K=1536.json new file mode 100644 index 0000000000..9a0a9fd275 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=3072-K=1536.json @@ -0,0 +1,62 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=4096-K=7168.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=4096-K=7168.json new file mode 100644 index 0000000000..aa648253c4 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=4096-K=7168.json @@ -0,0 +1,62 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "medium_M128": { + "BLOCK_SIZE_M": 128, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=4608-K=7168.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=4608-K=7168.json new file mode 100644 index 0000000000..9c2d9ed42d --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=4608-K=7168.json @@ -0,0 +1,62 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "medium_M128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff 
--git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=512-K=7168.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=512-K=7168.json new file mode 100644 index 0000000000..b00e6b8369 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=512-K=7168.json @@ -0,0 +1,62 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=7168-K=2048.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=7168-K=2048.json new file mode 100644 index 0000000000..6f8653065b --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=7168-K=2048.json @@ -0,0 +1,62 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + 
"num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "medium_M128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=7168-K=256.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=7168-K=256.json new file mode 100644 index 0000000000..d059b3c3aa --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-A8W8_BLOCKSCALE-N=7168-K=256.json @@ -0,0 +1,62 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "medium_M32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "medium_M128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=10240-K=8192.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=10240-K=8192.json new file mode 100644 index 0000000000..3036a5555c --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=10240-K=8192.json @@ -0,0 +1,86 @@ +{ + "small": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 1024, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "small_M16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 1024, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "medium_M128": { + "BLOCK_SIZE_M": 
64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=106496-K=16384.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=106496-K=16384.json new file mode 100644 index 0000000000..9895a4dc51 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=106496-K=16384.json @@ -0,0 +1,86 @@ +{ + "small": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "small_M16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + 
"NUM_KSPLIT": 1 + }, + "medium_M128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=16384-K=16384.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=16384-K=16384.json new file mode 100644 index 0000000000..7174492af2 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=16384-K=16384.json @@ -0,0 +1,86 @@ +{ + "small": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "small_M16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + 
"matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=16384-K=53248.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=16384-K=53248.json new file mode 100644 index 0000000000..6e56e82027 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=16384-K=53248.json @@ -0,0 +1,86 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "small_M16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, 
+ "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=18432-K=16384.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=18432-K=16384.json new file mode 100644 index 0000000000..13fe8985da --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=18432-K=16384.json @@ -0,0 +1,86 @@ +{ + "small": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "small_M16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=57344-K=8192.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=57344-K=8192.json new file mode 100644 index 0000000000..909689ce23 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=57344-K=8192.json @@ -0,0 +1,86 @@ +{ + "small": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "small_M16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + 
"NUM_KSPLIT": 1 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "medium_M128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=8192-K=28672.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=8192-K=28672.json new file mode 100644 index 0000000000..b7822a3a3b --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=8192-K=28672.json @@ -0,0 +1,86 @@ +{ + "small": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "small_M16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, 
+ "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 8 + }, + "medium_M128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=8192-K=8192.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=8192-K=8192.json new file mode 100644 index 0000000000..1989ff517f --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED-N=8192-K=8192.json @@ -0,0 +1,86 @@ +{ + "small": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 1024, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "small_M16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 1024, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 1024, + 
"GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 1024, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "large": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED.json b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED.json new file mode 100644 index 0000000000..eebf5eff50 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/MI350X-GEMM-AFP4WFP4_PRESHUFFLED.json @@ -0,0 +1,87 @@ +{ + "small": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "small_M16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "medium_M32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 
256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "medium_M64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "medium_M128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "large": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "xlarge": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } + +} diff --git a/aiter/ops/triton/configs/gemm/aot/README.md b/aiter/ops/triton/configs/gemm/aot/README.md new file mode 100644 index 0000000000..f320237595 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/README.md @@ -0,0 +1,6 @@ +# WARNING + +This is the first iteration of aot compilation for triton +All compiled modules are FP4 preshuffled GEMMs with TN layout and bf16 dtype + +All binary files are generated using triton==3.5.0+gitc172d539 \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..9a721b6230 Binary files 
/dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..218b217517 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "e90d4ba9cf14219bef1bca72767ed05991913eb79484a5b706cb25d9f2f71474", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 43520, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git 
a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..a76261d24b Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..4485032426 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "e80a3f3a19a5da27236f25e468c4b22caa88c28f65793d17c3d2045fe972817c", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 4, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", 
"instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 10752, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..19aa40e784 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..1e5bb1dfae --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "cb12dc32b0ed1a5ac880a6dd3bee50fb59d11e1a8eeccc3ae8153c968e7f2c75", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 4, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", 
"fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 13056, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..8dcc5280de Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..fca09fa225 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=1-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "347e0c55794ac0ca235e8b969a4b5a5268100a128f24dcce30fe2005b2bc21b1", "target": {"backend": "hip", "arch": 
"gfx950", "warp_size": 64}, "num_warps": 4, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 16896, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..febfd8cf3b Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json 
b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..84e66f815f --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "a729967cd59e3c39a6f61dd259cc2b7cd9768909003d37d03d9dc7dae7280b9e", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 52224, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 
0000000000..1620f0f01d Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..56a2dfd70c --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "a38121d8f5709315553f0016ca0e08c77bfd16fd57e336ed676b85615be00762", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 4, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 13056, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ 
No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..cfdd8d48cc Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..fc4ab35831 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "4af84e6c0b5acb21f71e7f71ab43f43a465dd74734d7c6def0d9fc859c471c1f", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, 
"max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 10752, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..febfd8cf3b Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..84e66f815f --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=16-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "a729967cd59e3c39a6f61dd259cc2b7cd9768909003d37d03d9dc7dae7280b9e", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, 
"arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 52224, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..d5d5dde2c6 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..046da07114 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": 
"41f673542e895bf56edb8e6a137febf789c28a9da5b4693a1065490a62336656", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 43520, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..6ce75e81b2 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json 
b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..7b2c5ab8de --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "1b29a972364a81e3844504157096f1a0ca2164836cee9758c885f562921d6f0c", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 4, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 10752, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 
0000000000..a43f1de3b1 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..ac8df07cd5 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "96e931c980f9bd1b0d7ba209973d637dae985113c68e23d0476ea6a3789b77f4", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 4, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 13056, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No 
newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..bdcde11a55 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..3c5b94ec80 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=2-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "3a4dedb1720cc3cc439e59bf26afb541bb9591e17d932fe1696e4ac46fe1c376", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 4, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, 
"max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 16896, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..ba18fd23df Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..8dc6746eff --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "518af245b3686a62c8aae8b677a2e83177124a639e544e12c11c00b9797474fd", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": 
true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 9728, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..cf679ceb39 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..62c83cddea --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": 
"1737c0a38627fe5406a6244d0c66b46e3b98dfd8daf99c31b2c2ab219ffd8249", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 1, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 10752, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..b34288dac8 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json 
b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..859c3bf7e2 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "84f327b5729d25ec4ad344f8a9b211f9c9786815df9873b33e1a44d2cdf8e580", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 19456, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 
0000000000..b7504c5898 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..8b0d6ebf34 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=32-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "4f511f2573c219ee1928e586a5facd24ea5ddbd2f6314d14387f45c2ca36905b", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 4, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 43008, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No 
newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..d5d5dde2c6 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..046da07114 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "41f673542e895bf56edb8e6a137febf789c28a9da5b4693a1065490a62336656", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, 
"max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 43520, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..6ce75e81b2 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..7b2c5ab8de --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "1b29a972364a81e3844504157096f1a0ca2164836cee9758c885f562921d6f0c", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 4, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, 
"arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 10752, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..a43f1de3b1 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..ac8df07cd5 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": 
"96e931c980f9bd1b0d7ba209973d637dae985113c68e23d0476ea6a3789b77f4", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 4, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 13056, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..bdcde11a55 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json 
b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..3c5b94ec80 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=4-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "3a4dedb1720cc3cc439e59bf26afb541bb9591e17d932fe1696e4ac46fe1c376", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 4, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 16896, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 
0000000000..36a56dbea1 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..a6f0809dde --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "f22922a7294924d71ca6c72a6b4ac34c07ff79ccf09d45e9fea0fcec2660ee0c", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 19456, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ 
No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..b5e58ce97c Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..9f7f4d9500 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "7ee2132b54aabbbef6a1f5cc7a99ad94d8c6ee8420e5a6fb8702168c0df06a5d", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 1, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, 
"max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 38912, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..4240460535 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..5373477c35 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "1007fa9b77c1c41ab0d7f1875b4474e4c8e58481c2f80bfcfbe0ee0131caa0e5", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 1, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": 
true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 38912, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..ec1d731f5a Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..b0be146e7f --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=64-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": 
"e9a4e058b4c9508aa7b4c8c5c8ff9bba7f3a3c069f2492dbac912115e7a4108a", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 4, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 77824, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..d5d5dde2c6 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json 
b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..046da07114 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=10240-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "41f673542e895bf56edb8e6a137febf789c28a9da5b4693a1065490a62336656", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 43520, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 
0000000000..6ce75e81b2 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..7b2c5ab8de --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=57344-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "1b29a972364a81e3844504157096f1a0ca2164836cee9758c885f562921d6f0c", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 4, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 10752, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No 
newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..a43f1de3b1 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..ac8df07cd5 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=8192-K=28672/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "96e931c980f9bd1b0d7ba209973d637dae985113c68e23d0476ea6a3789b77f4", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 2, "waves_per_eu": 4, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, "arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, 
"max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 13056, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco new file mode 100644 index 0000000000..bdcde11a55 Binary files /dev/null and b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.hsaco differ diff --git a/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json new file mode 100644 index 0000000000..3c5b94ec80 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/aot/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales_M=8-N=8192-K=8192/_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.json @@ -0,0 +1 @@ +{"hash": "3a4dedb1720cc3cc439e59bf26afb541bb9591e17d932fe1696e4ac46fe1c376", "target": {"backend": "hip", "arch": "gfx950", "warp_size": 64}, "num_warps": 4, "waves_per_eu": 2, "num_stages": 2, "num_ctas": 1, "extern_libs": [["ocml", "/app/triton-tot/python/triton/backends/amd/lib/ocml.bc"], ["ockl", "/app/triton-tot/python/triton/backends/amd/lib/ockl.bc"]], "cluster_dims": [1, 1, 1], "debug": false, "sanitize_overflow": true, 
"arch": "gfx950", "supported_fp8_dtypes": ["fp8e4b8", "fp8e4nv", "fp8e5", "fp8e5b16"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b8", "fp8e5b16"], "default_dot_input_precision": "ieee", "allowed_dot_input_precisions": ["ieee", "bf16x3", "bf16x6"], "enable_fp_fusion": true, "launch_cooperative_grid": false, "matrix_instr_nonkdim": 16, "kpack": 1, "allow_flush_denorm": false, "max_num_imprecise_acc_default": 0, "backend_name": "hip", "instrumentation_mode": "", "schedule_hint": "none", "warp_size": 64, "AMDGCN_USE_BUFFER_OPS": "true", "TRITON_HIP_USE_ASYNC_COPY": "true", "TRITON_HIP_USE_BLOCK_PINGPONG": "true", "triton_version": "3.5.0", "shared": 16896, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "_gemm_afp4_wfp4_kernel_preshuffled_weight_scales"} \ No newline at end of file diff --git a/aiter/ops/triton/fused_add_rmsnorm_pad.py b/aiter/ops/triton/fused_add_rmsnorm_pad.py new file mode 100644 index 0000000000..d902e3e44e --- /dev/null +++ b/aiter/ops/triton/fused_add_rmsnorm_pad.py @@ -0,0 +1,54 @@ +import torch +import triton +from aiter.ops.triton._triton_kernels.fused_add_rmsnorm_pad import ( + _fused_add_rmsnorm_pad, +) + + +def fused_add_rmsnorm_pad( + x: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + res: torch.Tensor = None, + x_pad_to_multiple: int = 0, +): + M, N = x.shape + + if x_pad_to_multiple > 0: + N_out = triton.cdiv(N, x_pad_to_multiple) * x_pad_to_multiple + else: + N_out = N + out = torch.empty((M, N_out), dtype=x.dtype, device=x.device) + + res_out = None + if res is not None: + M2, N2 = res.shape + assert M == M2, "Shape error!" + assert N == N2, "Shape error!" 
+ res_out = torch.empty((M, N), dtype=res.dtype, device=res.device) + BLOCK_SIZE_N = triton.next_power_of_2(N_out) + _fused_add_rmsnorm_pad[(M,)]( + x, + res, + out, + res_out, + weight, + epsilon, + M, + N, + N_out, + x.stride(0), + x.stride(1), + res.stride(0) if res is not None else 0, + res.stride(1) if res is not None else 0, + out.stride(0), + out.stride(1), + res_out.stride(0) if res is not None else 0, + res_out.stride(1) if res is not None else 0, + HAS_RES=(res is not None), + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + + if res is not None: + return out, res_out + return out diff --git a/aiter/ops/triton/fused_fp8_quant.py b/aiter/ops/triton/fused_fp8_quant.py new file mode 100644 index 0000000000..39e1c58777 --- /dev/null +++ b/aiter/ops/triton/fused_fp8_quant.py @@ -0,0 +1,336 @@ +from typing import Optional +import torch +import triton +import aiter +from aiter.ops.triton._triton_kernels.fused_fp8_quant import ( + _fused_rms_fp8_group_quant_kernel, + _fused_flatten_fp8_group_quant_kernel, + _fused_reduce_act_mul_fp8_group_quant, +) +from aiter.ops.triton._triton_kernels.activation import ( + _get_activation_from_str, +) +from aiter.ops.triton.utils.logger import AiterTritonLogger + +_LOGGER = AiterTritonLogger() + + +fp8_dtype = aiter.dtypes.fp8 + + +def fused_rms_fp8_group_quant( + inp1, + inp1_weight, + inp1_epsilon, + inp2=None, + inp2_weight=None, + inp2_epsilon=None, + group_size=128, + dtype_quant=fp8_dtype, + res1=None, + output_unquantized_inp1=False, +): + """ + This op contains several steps: + 1. if res1 is not None, inp1 = inp1 + res1, and store inp1 to out_res1 + 2. perform RMS norm along the last dimenion for inp1 + 3. if inp2 is not None, perform RMS norm along the last dimenion for inp2 + 4. perform fp8 quantization for inp1 only + + Key parameters: + - x: Matrix X with shape (M, N1, N2). + + Returns: + - out1_fp8: The output matrix with shape (M, N1). + - out1_bs: The output matrix with shape (M, cdiv(N1, group_size)). 
+ - out1: The output matrix with shape (M, N1). + - out2: The output matrix with shape (M, N2). + - out_res1: The output matrix with shape (M, N1). + - out1: The output matrix with shape (M, N1). + """ + + M, N1 = inp1.shape + BLOCK_SIZE_N = max(triton.next_power_of_2(N1), group_size) + if inp2 is not None: + M2, N2 = inp2.shape + BLOCK_SIZE_N = max(triton.next_power_of_2(N2), BLOCK_SIZE_N) + assert ( + M == M2 + ), "The leading dimension should be identical between inp1 and inp2" + else: + N2 = 0 + out1_fp8 = torch.empty((M, N1), dtype=dtype_quant, device=inp1.device) + out1_bs = torch.empty( + (M, (N1 + group_size - 1) // group_size), + dtype=torch.float32, + device=inp1.device, + ) + + out2 = None + out2_row_stride = 0 + out2_col_stride = 0 + inp2_row_stride = 0 + inp2_col_stride = 0 + if inp2 is not None: + out2 = torch.empty((M, N2), dtype=inp1.dtype, device=inp1.device) + inp2_row_stride = inp2.stride(0) + inp2_col_stride = inp2.stride(1) + out2_row_stride = out2.stride(0) + out2_col_stride = out2.stride(1) + + out1 = None + out1_row_stride = 0 + out1_col_stride = 0 + if output_unquantized_inp1: + out1 = torch.empty((M, N1), dtype=inp1.dtype, device=inp1.device) + out1_row_stride = out1.stride(0) + out1_col_stride = out1.stride(1) + + BLOCK_SIZE_N = max(BLOCK_SIZE_N, group_size) + out_res1 = None + res1_row_stride = 0 + res1_col_stride = 0 + out_res1_row_stride = 0 + out_res1_col_stride = 0 + if res1 is not None: + Mr, Nr = res1.shape + assert ( + M == Mr and N1 == Nr + ), "The shape should be identical between inp1 and res1" + out_res1 = torch.empty((M, N1), dtype=inp1.dtype, device=inp1.device) + res1_row_stride = res1.stride(0) + res1_col_stride = res1.stride(1) + out_res1_row_stride = out_res1.stride(0) + out_res1_col_stride = out_res1.stride(1) + + if BLOCK_SIZE_N <= 512: + num_warps = 1 + elif BLOCK_SIZE_N <= 2048: + num_warps = 4 + elif BLOCK_SIZE_N <= 4096: + num_warps = 8 + else: + num_warps = 16 + + DTYPE_MAX = ( + torch.finfo(out1_fp8.dtype).max + 
if torch.is_floating_point(out1_fp8) + else torch.iinfo(out1_fp8.dtype).max + ) + _fused_rms_fp8_group_quant_kernel[(M,)]( + inp1, + inp1_weight, + inp2, + inp2_weight, + res1, + out1_fp8, + out1_bs, + out2, + out_res1, + out1, + inp1_epsilon, + inp2_epsilon, + M, + N1, + N2, + inp1.stride(0), + inp2_row_stride, + inp1.stride(1), + inp2_col_stride, + res1_row_stride, + res1_col_stride, + out1_fp8.stride(0), + out1_fp8.stride(1), + out1_bs.stride(0), + out1_bs.stride(1), + out2_row_stride, + out2_col_stride, + out_res1_row_stride, + out_res1_col_stride, + out1_row_stride, + out1_col_stride, + BLOCK_SIZE_N=BLOCK_SIZE_N, + QUANT_BLOCK_SIZE=group_size, + DTYPE_MAX=DTYPE_MAX, + DTYPE_MIN=-DTYPE_MAX, + HAVE_SECOND_INPUT=(inp2 is not None), + FIRST_INPUT_RES=(res1 is not None), + FIRST_INPUT_OUT=output_unquantized_inp1, + num_warps=num_warps, + ) + + return (out1_fp8, out1_bs), out1, out2, out_res1 + + +def fused_flatten_fp8_group_quant( + x: torch.Tensor, + group_size, + dtype_quant=fp8_dtype, +): + """ + Flatten the last two dimension of x and perform FP8 per-token group quantization along the last dimension + + Key parameters: + - x: Matrix X with shape (M, N1, N2). + + Returns: + - out: The output matrix with shape (M, N1 * N2). + - out_block_scales: The output matrix with shape (M, cdiv((N1 * N2), group_size)). 
+ """ + M, N1, N2 = x.shape + + BLOCK_SIZE_N2 = max(triton.next_power_of_2(N2), group_size) + N = N1 * N2 + out = torch.empty((M, N), dtype=dtype_quant, device=x.device) + out_block_scales = torch.empty( + (M, triton.cdiv(N, group_size)), dtype=torch.float32, device=x.device + ) + + DTYPE_MAX = ( + torch.finfo(out.dtype).max + if torch.is_floating_point(out) + else torch.iinfo(out.dtype).max + ) + grid = ( + M, + N1, + ) + _fused_flatten_fp8_group_quant_kernel[grid]( + x, + out, + out_block_scales, + *x.stride(), + *out.stride(), + *out_block_scales.stride(), + N2, + BLOCK_SIZE_N2=BLOCK_SIZE_N2, + QUANT_BLOCK_SIZE=group_size, + DTYPE_MAX=DTYPE_MAX, + DTYPE_MIN=-DTYPE_MAX, + ) + + return out, out_block_scales + + +def fused_reduce_act_mul_fp8_group_quant( + x: torch.Tensor, + activation: str = "silu", + x2: Optional[torch.Tensor] = None, + group_size=128, + dtype_quant=fp8_dtype, + dtype: Optional[float] = torch.bfloat16, +): + """ + Apply reduction along the first dimension and apply the activation function + per-token group quantization. + If x2 is provided, the only reduction along the first dimension is applied to x2 + + Args: + if x is 3-dim, + x: (SPK, M, 2*N1), dtype = fp32. + x2: (SPK, M, 2*N1), dtype = fp32. + + if x is 2-dim, + x: (M, N2), dtype = fp16 or bf16. + x2 must be None + the kernel is essentially identical to aiter.ops.triton.activation.act_mul_and_fp8_group_quant + + activation: activation function to apply before quantization. + - It splits the features into two parts and applies the activation to the first part. + - Then, it adds the results together before quantization. 
+ - Supports the following activations: + - "silu" + - "gelu" + - "gelu_tanh" + + Returns: + tuple: (y, y_scale), y2 + y: (M, N1), dtype = dtype_quant + y_scale: (M, cdiv(N1, group_size)), dtype = fp32 + y2: (M, N2), dtype = dtype + """ + _LOGGER.info(f"FUSED_REDUCTION_ACT_MUL_FP8_GROUP_QUANT: x={tuple(x.shape)}") + + assert ( + x.dim() == 2 or x.dim() == 3 + ), "The number of dimentions for x should be 2 or 3" + X_HAS_SPLITK = False + x_num_splitk = 1 + N2 = 1 + y2 = None + if x.dim() == 3: + x_num_splitk, M, N1 = x.shape + x_num_splitk, _, N2 = x2.shape + assert ( + x.shape[0] == x2.shape[0] and x.shape[1] == x2.shape[1] + ), "The first two dimensions should be identical between x and x2" + assert ( + x_num_splitk > 1 + ), "x.shape[0] should be larger then 1 in x.dim() == 3 cases" + X_HAS_SPLITK = True + y2 = torch.empty((M, N2), dtype=dtype, device=x2.device) + else: + M, N1 = x.shape + assert x2 is None, "x2 should be None in x.dim() == 2 cases" + + assert ( + N1 % 2 == 0 + ), "The last dimension for x1 should be multiple of 2 for acitvation and multiplication" + N1 = N1 // 2 + + y = torch.empty((M, N1), dtype=dtype_quant, device=x.device) + y_scale = torch.empty( + (M, (N1 + group_size - 1) // group_size), + dtype=torch.float32, + device=x.device, + ) + + BLOCK_SIZE_N1 = max(triton.next_power_of_2(N1), group_size) + BLOCK_SIZE_N2 = max(triton.next_power_of_2(N2), 32) + BLOCK_SIZE_M2 = 1 if M <= 128 else 4 + X_MASK = N1 % BLOCK_SIZE_N1 != 0 + + DTYPE_MAX = ( + torch.finfo(y.dtype).max + if torch.is_floating_point(y) + else torch.iinfo(y.dtype).max + ) + num_pid = M + if X_HAS_SPLITK: + num_pid += triton.cdiv(M, BLOCK_SIZE_M2) * triton.cdiv(N2, BLOCK_SIZE_N2) + grid = (num_pid,) + _fused_reduce_act_mul_fp8_group_quant[grid]( + x, + y, + y_scale, + x2, + y2, + M, + N1, + N2, + 0 if not X_HAS_SPLITK else x.stride(0), + x.stride(0) if not X_HAS_SPLITK else x.stride(1), + x.stride(1) if not X_HAS_SPLITK else x.stride(2), + y.stride(0), + y.stride(1), + 
y_scale.stride(0), + y_scale.stride(1), + 0 if not X_HAS_SPLITK else x2.stride(0), + 0 if not X_HAS_SPLITK else x2.stride(1), + 0 if not X_HAS_SPLITK else x2.stride(2), + 0 if not X_HAS_SPLITK else y2.stride(0), + 0 if not X_HAS_SPLITK else y2.stride(1), + ACTIVATION=_get_activation_from_str(activation) if activation else "", + BLOCK_SIZE_M2=BLOCK_SIZE_M2, + BLOCK_SIZE_N1=BLOCK_SIZE_N1, + BLOCK_SIZE_N2=BLOCK_SIZE_N2, + QUANT_BLOCK_SIZE=group_size, + DTYPE_MAX=DTYPE_MAX, + DTYPE_MIN=-DTYPE_MAX, + X_HAS_SPLITK=X_HAS_SPLITK, + X_NUM_KSPLIT=x_num_splitk, + X_NUM_KSPLIT_POW2=triton.next_power_of_2(x_num_splitk), + X_MASK=X_MASK, + num_warps=1 if max(BLOCK_SIZE_N1, BLOCK_SIZE_N2) <= 512 else 4, + ) + + return (y, y_scale), y2 diff --git a/aiter/ops/triton/fused_gemm_a8w8_blockscale_a16w16.py b/aiter/ops/triton/fused_gemm_a8w8_blockscale_a16w16.py new file mode 100644 index 0000000000..40f5f62633 --- /dev/null +++ b/aiter/ops/triton/fused_gemm_a8w8_blockscale_a16w16.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +from typing import Optional +import functools +import json +import os +import torch +import triton +import triton.language as tl +from aiter.ops.triton._triton_kernels.fused_gemm_a8w8_blockscale_a16w16 import ( + _fused_gemm_a8w8_blockscale_a16w16_kernel, + _fused_gemm_a8w8_blockscale_a16w16_reduce_kernel, + _get_config, +) +from aiter.ops.triton.utils.logger import AiterTritonLogger + +_LOGGER = AiterTritonLogger() + + +def fused_gemm_a8w8_blockscale_a16w16( + x_fp8: torch.Tensor, + w_fp8: torch.Tensor, + x_fp8_scale: torch.Tensor, + w_fp8_scale: torch.Tensor, + x_bf16: torch.Tensor, + w_bf16: torch.Tensor, + bias_fp8: Optional[torch.Tensor] = None, + bias_bf16: Optional[torch.Tensor] = None, + dtype: Optional[float] = torch.bfloat16, + y_fp8: Optional[torch.Tensor] = None, + y_bf16: Optional[torch.Tensor] = None, + skip_reduce: Optional[bool] = False, + config: Optional[dict] = None, +): + """ + Computes the 8 bit matmul Y = X x WT + B using the block-scale quantization approach for x_fp8 and w_fp8. + Computes the 16 bit matmul Y = X x WT + B for x_bf16 and w_bf16 + + This fusion is primarily aiming for fusing the gate up-projections and MOE gating: + gate up-projections: (M, K) x (2N, K) = (M, 2N) + MOE gating: (M, K) x (N, K) + (N, ) = (M, N) + + Key parameters: + - x_fp8: Matrix X with shape (M, K). + - w_fp8: Matrix W with shape (N_fp8, K). + - x_fp8_scale: Scale tensor for X with shape (M, *scale_k). + - w_fp8_scale: Scale tensor for W with shape (**scale_n, *scale_k). + - x_bf16: Matrix X with shape (M, K). + - w_bf16: Matrix W with shape (N_fp8, K). + + Note: M, N, K must be identical for x_fp8 and x_bf16, but the N-dim fow w_fp8 and w_bf16 can be different + + Returns: + - Y: The output matrix with shape (M, N). 
+ + *scale_k = (K + scale_block_size_k - 1) // scale_block_size_k + **scale_n = (N_fp8 + scale_block_size_n - 1) // scale_block_size_n + """ + _LOGGER.info( + f"FUSED_GEMM_A8W8_BLOCKSCALE_A16W16: x_fp8={tuple(x_fp8.shape)} w_fp8={tuple(w_fp8.shape)} x_fp8_scale={tuple(x_fp8_scale.shape)} w_scale={tuple(w_fp8_scale.shape)} x_bf16={tuple(x_bf16.shape)} w_bf16={tuple(w_bf16.shape)}" + ) + + M, K = x_fp8.shape + N_fp8, K = w_fp8.shape + M, K = x_bf16.shape + N_bf16, K = w_bf16.shape + + # Check constraints. + assert ( + x_fp8.shape[0] == x_bf16.shape[0] + ), "M-dim should be identical for x_fp8 and x_bf16" + assert ( + x_fp8.shape[1] == x_bf16.shape[1] + ), "K-dim should be identical for x_fp8 and x_bf16" + assert x_fp8.shape[1] == w_fp8.shape[1], "Incompatible dimensions!!!" + assert w_bf16.shape[1] == w_bf16.shape[1], "Incompatible dimensions!!!" + + # Transpose w and w_scale + w_fp8 = w_fp8.T + w_bf16 = w_bf16.T + w_fp8_scale = w_fp8_scale.T + + if config is None: + config = _get_config(M, N_fp8, N_bf16, K) + + if y_fp8 is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): + y_fp8 = torch.empty((M, N_fp8), dtype=dtype, device=x_fp8.device) + + if y_bf16 is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): + y_bf16 = torch.empty((M, N_bf16), dtype=dtype, device=x_bf16.device) + + config["SPLITK_BLOCK_SIZE"] = triton.cdiv(K, config["NUM_KSPLIT"]) + if config["NUM_KSPLIT"] > 1: + y_fp8_pp = torch.empty( + (config["NUM_KSPLIT"], M, N_fp8), + dtype=torch.float32, + device=y_fp8.device if y_fp8 is not None else x_fp8.device, + ) + y_bf16_pp = torch.empty( + (config["NUM_KSPLIT"], M, N_bf16), + dtype=torch.float32, + device=y_bf16.device if y_bf16 is not None else x_bf16.device, + ) + else: + y_fp8_pp = None + y_bf16_pp = None + + if config["BLOCK_SIZE_K"] > config["SPLITK_BLOCK_SIZE"]: + config["BLOCK_SIZE_K"] = triton.next_power_of_2(config["SPLITK_BLOCK_SIZE"]) + if config["BLOCK_SIZE_K"] > config["SPLITK_BLOCK_SIZE"]: + config["BLOCK_SIZE_K"] = 
config["BLOCK_SIZE_K"] // 4 + config["BLOCK_SIZE_K"] = max(config["BLOCK_SIZE_K"], 16) + + # Scale block sizes + # TODO: need a better way to pass scale block sizes around + config["GROUP_K"] = triton.next_power_of_2(triton.cdiv(K, w_fp8_scale.shape[0])) + config["GROUP_N"] = triton.next_power_of_2(triton.cdiv(N_fp8, w_fp8_scale.shape[1])) + + # grid = (config["NUM_KSPLIT"], triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(N, config["BLOCK_SIZE_N"]),) + grid = lambda META: ( # noqa: E731 + ( + META["NUM_KSPLIT"] + * triton.cdiv(M, META["BLOCK_SIZE_M"]) + * ( + triton.cdiv(N_fp8, META["BLOCK_SIZE_N"]) + + triton.cdiv(N_bf16, META["BLOCK_SIZE_N"]) + ) + ), + ) + _fused_gemm_a8w8_blockscale_a16w16_kernel[grid]( + x_fp8, + w_fp8, + bias_fp8, + x_fp8_scale, + w_fp8_scale, + y_fp8 if config["NUM_KSPLIT"] == 1 else y_fp8_pp, + x_bf16, + w_bf16, + bias_bf16, + y_bf16 if config["NUM_KSPLIT"] == 1 else y_bf16_pp, + M, + N_fp8, + N_bf16, + K, + x_fp8.stride(0), + x_fp8.stride(1), + w_fp8.stride(0), + w_fp8.stride(1), + x_fp8_scale.stride(0), + x_fp8_scale.stride(1), + w_fp8_scale.stride(0), + w_fp8_scale.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_fp8_pp.stride(0), + y_fp8.stride(0) if config["NUM_KSPLIT"] == 1 else y_fp8_pp.stride(1), + y_fp8.stride(1) if config["NUM_KSPLIT"] == 1 else y_fp8_pp.stride(2), + x_bf16.stride(0), + x_bf16.stride(1), + w_bf16.stride(0), + w_bf16.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_bf16_pp.stride(0), + y_bf16.stride(0) if config["NUM_KSPLIT"] == 1 else y_bf16_pp.stride(1), + y_bf16.stride(1) if config["NUM_KSPLIT"] == 1 else y_bf16_pp.stride(2), + ADD_BIAS_FP8=(bias_fp8 is not None), + ADD_BIAS_BF16=(bias_bf16 is not None), + SKIP_REDUCE=skip_reduce, + **config, + ) + + if config["NUM_KSPLIT"] > 1: + if skip_reduce: + return y_fp8_pp, y_bf16_pp + REDUCE_BLOCK_SIZE_M = 32 + REDUCE_BLOCK_SIZE_N = 32 + ACTUAL_KSPLIT = triton.cdiv(K, config["SPLITK_BLOCK_SIZE"]) + + grid_reduce = ( + triton.cdiv(M, REDUCE_BLOCK_SIZE_M), + 
triton.cdiv(N_fp8, REDUCE_BLOCK_SIZE_N) + + triton.cdiv(N_bf16, REDUCE_BLOCK_SIZE_N), + ) + _fused_gemm_a8w8_blockscale_a16w16_reduce_kernel[grid_reduce]( + bias_fp8, + y_fp8_pp, + y_fp8, + bias_bf16, + y_bf16_pp, + y_bf16, + M, + N_fp8, + N_bf16, + y_fp8_pp.stride(0), + y_fp8_pp.stride(1), + y_fp8_pp.stride(2), + y_fp8.stride(0), + y_fp8.stride(1), + y_bf16_pp.stride(0), + y_bf16_pp.stride(1), + y_bf16_pp.stride(2), + y_bf16.stride(0), + y_bf16.stride(1), + REDUCE_BLOCK_SIZE_M, + REDUCE_BLOCK_SIZE_N, + ACTUAL_KSPLIT, + triton.next_power_of_2(config["NUM_KSPLIT"]), + ADD_BIAS_FP8=(bias_fp8 is not None), + ADD_BIAS_BF16=(bias_bf16 is not None), + ) + + return y_fp8, y_bf16 diff --git a/aiter/ops/triton/fused_kv_cache.py b/aiter/ops/triton/fused_kv_cache.py index 13da3caf1f..9b92f07563 100644 --- a/aiter/ops/triton/fused_kv_cache.py +++ b/aiter/ops/triton/fused_kv_cache.py @@ -1,7 +1,8 @@ import torch import triton -import triton.language as tl from aiter.ops.triton._triton_kernels.fused_kv_cache import ( + _fused_qk_rope_cat_and_cache_mla_kernel, + _fused_qk_rope_reshape_and_cache_kernel, _fused_qk_rope_cosine_cache_llama_kernel, ) from aiter.ops.triton.utils.logger import AiterTritonLogger @@ -9,6 +10,342 @@ _LOGGER = AiterTritonLogger() +def fused_qk_rope_cat_and_cache_mla( + q_nope: torch.Tensor, + q_pe: torch.Tensor, + k_nope: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + pos: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + k_scale: torch.Tensor, + is_neox: bool, + num_decode_toks_for_zeros: int = 0, + apply_scale: bool = True, + q_out: torch.Tensor = None, + decode_q_pe_out: torch.Tensor = None, + k_pe_out: torch.Tensor = None, + q_out_dtype=None, +): + """ + Perform RoPE on q_pe and k_pe and concat q_nope with q_pe and k_nope with k_pe along the last dimension + the concatentaed k_nope and k_pe are copied to kv_cache inplace + + Key parameters: + - q_nope: Matrix X with shape (B, QH, D1). 
+ - q_pe: Matrix W with shape (B, QH, D2). + - k_nope: Matrix X with shape (B_slot, KH, D1). + - k_pe: Matrix W with shape (B_slot, KH, D2). + - kv_cache: Matrix W with shape (B_cache, KH, D1 + D2). + - slot_mapping: Matrix W with shape (B_slot, ). + + B is the number of decode tokens, B_slot is the number of prefill + decode tokens, B_cahce is the max number of tokens of kv_cache + QH must be multiple of KH + + Returns: + - q_out: The output matrix with shape (B, QH, D1+D2). + - kv_cache: The output matrix with shape (B_max, KH, D1 + D2) (inplace). + """ + _LOGGER.info( + f"FUSED_QK_ROPE_CAT_AND_CACHE_MLA: q_nope={tuple(q_nope.shape)} q_pe={tuple(q_pe.shape)} k_nope={tuple(k_nope.shape)} k_pe={tuple(k_pe.shape)} " + + f"pos={tuple(pos.shape)} cos={tuple(cos.shape)} sin={tuple(sin.shape)} kv_cache={tuple(kv_cache.shape)} slot_mapping={tuple(slot_mapping.shape)}" + ) + + b, qh, d_nope = q_nope.shape + b2, qh2, d_pe = q_pe.shape + bk, kh, dk_nope = k_nope.shape + bk2, kh2, dk2 = k_pe.shape + b_cache, h_cache, d_cache = kv_cache.shape + (b_slot,) = slot_mapping.shape + + assert ( + b_slot <= b and b == b2 == bk == bk2 + ), "batch dimension should be identical for q_nope, q_pe, k_nope, and k_pe, and the batch dimeion of slot_mapping should be no more than that of q_nope, q_pe, k_nope, and k_pe" + assert qh == qh2, "Q head should be identical" + assert kh == kh2 == h_cache, "K head should be identical" + assert d_pe == dk2, "D dimension of q_pe and k_pe should be identical" + assert ( + dk_nope + dk2 == d_cache + ), "D dimension of k_nope and k_pe should be summed up to be the D dimension of kv_cache" + assert qh % kh == 0, "Q heads must be multiple of H heads" + d_freq = cos.shape[-1] + assert (d_freq == d_pe // 2) or ( + d_freq == d_pe + ), "cos/sin last dim should be the same or half of the qk last dim" + if isinstance(k_scale, torch.Tensor): + assert k_scale.numel() == 1, "k_scale should be a single-element torch.Tensor" + reuse_freqs_front_part = d_freq == d_pe // 
2 + + if q_out is None: + q_out = torch.empty( + (b, qh, d_nope + d_pe), + dtype=q_out_dtype if q_out_dtype is not None else q_nope.dtype, + device=q_nope.device, + ) + else: + b_q_out, qh_q_out, d_q_out = q_out.shape + assert ( + b == b_q_out and qh == qh_q_out and d_nope + d_pe == d_q_out + ), "q_out shape mismatch" + + if decode_q_pe_out is None: + decode_q_pe_out = torch.empty( + (num_decode_toks_for_zeros, qh, d_pe), + dtype=q_nope.dtype, + device=q_nope.device, + ) + else: + b_decode_q_pe_out, qh_decode_q_pe_out, d_decode_q_pe_out = decode_q_pe_out.shape + assert ( + num_decode_toks_for_zeros == b_decode_q_pe_out + and qh == qh_decode_q_pe_out + and d_pe == d_decode_q_pe_out + ), "decode_q_pe_out shape mismatch" + + if k_pe_out is None: + k_pe_out = torch.empty((b, kh, d_pe), dtype=k_pe.dtype, device=k_pe.device) + else: + b_k_pe_out, hk_k_pe_out, d_k_pe_out = k_pe_out.shape + assert ( + b == b_k_pe_out and kh == hk_k_pe_out and d_pe == d_k_pe_out + ), "k_pe_out shape mismatch" + + q_nope_zeros_out = None + if num_decode_toks_for_zeros > 0: + q_nope_zeros_out = torch.empty( + (num_decode_toks_for_zeros, qh, dk_nope), + dtype=q_nope.dtype, + device=q_nope.device, + ) + + n_pid = b * qh + (b_slot - b) * kh + grid = (n_pid, 1, 1) + _fused_qk_rope_cat_and_cache_mla_kernel[grid]( + q_nope, + q_pe, + k_nope, + k_pe, + pos, + cos, + sin, + q_out, + decode_q_pe_out, + k_pe_out, + q_nope_zeros_out, + kv_cache, + slot_mapping, + b, + b_slot, + num_decode_toks_for_zeros, + *q_nope.stride(), + *q_pe.stride(), + *k_nope.stride(), + *k_pe.stride(), + pos.stride(0), + cos.stride(0), + cos.stride(-1), + *q_out.stride(), + *decode_q_pe_out.stride(), + *k_pe_out.stride(), + q_nope_zeros_out.stride(0) if q_nope_zeros_out is not None else 0, + q_nope_zeros_out.stride(1) if q_nope_zeros_out is not None else 0, + q_nope_zeros_out.stride(2) if q_nope_zeros_out is not None else 0, + *kv_cache.stride(), + k_scale_ptr=k_scale, + QH_PER_KH=qh // kh, + QH=qh, + KH=kh, + 
REUSE_FREQS_FRONT_PART=reuse_freqs_front_part, + IS_NEOX=is_neox, + BLOCK_D_nope=d_nope, + BLOCK_DK_nope=dk_nope, + BLOCK_D_pe=d_pe, + BLOCK_D_HALF_pe=d_pe // 2, + OUTPUT_Q_NOPE_ZEROS=(q_nope_zeros_out is not None), + HAVE_K_SCALE=(k_scale is not None and apply_scale), + num_warps=1, + ) + + if num_decode_toks_for_zeros > 0: + return q_out, decode_q_pe_out, k_pe_out, kv_cache, q_nope_zeros_out + return q_out, decode_q_pe_out, k_pe_out, kv_cache + + +def fused_qk_rope_reshape_and_cache( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + pos: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + is_neox: bool, + flash_layout: bool, + apply_scale: bool = True, + offs: torch.Tensor = None, + q_out: torch.Tensor = None, + k_out: torch.Tensor = None, + output_zeros: bool = True, + zeros_out: torch.Tensor = None, +): + """ + Perform RoPE on q and k and along the last dimension and copy k and v in to key_cache and value_cache inplace + + Key parameters: + - q: shape (T, QH, D). + - k: shape (T_slot, KH, D). + - v: shape (T_slot, KH, D). + - if flash_layout: + - key_cache: shape (T_cache, block_size, KH, D). + - value_cache: shape (T_cache, block_size, KH, D). + - else: + - key_cache: shape (T_cache, KH, D // x, block_size, x). + - value_cache: shape (T_cache, KH, D, block_size). + - slot_mapping: shape (T_slot, ). + + T is the number of decode tokens, T_cahce * block_size is the max number of tokens of kv_cache + QH must be multiple of KH + + Returns: + - q_out: same shape as input q. + - k_out: same shape as input k. + - key_cache: same shape as input key_cache (inplace). + - value_cache: same shape as input value_cache (inplace). + - zeros_out: same shape as input q. 
+ """ + _LOGGER.info( + f"FUSED_QK_ROPE_RESHAPE_AND_CACHE: q={tuple(q.shape)} k={tuple(k.shape)} " + + f"pos={tuple(pos.shape)} cos={tuple(cos.shape)} sin={tuple(sin.shape)} key_cache={tuple(key_cache.shape)} value_cache={tuple(value_cache.shape)} slot_mapping={tuple(slot_mapping.shape)}" + ) + + t, qh, d = q.shape + tk, kh, dk = k.shape + tv, vh, dv = v.shape + if flash_layout: + t_cache, block_size, kh_cache, dk_cache = key_cache.shape + t_cache_v, block_size_v, vh_cache, dv_cache = value_cache.shape + else: + t_cache, kh_cache, dkx_cache, block_size, x_cache = key_cache.shape + t_cache_v, vh_cache, dv_cache, block_size_v = value_cache.shape + (t_slot,) = slot_mapping.shape + + assert ( + t == tk == tv and t_slot <= tk + ), f"Number of tokens should be identical for q, kand v. The number of tokens of slot_mapping should no more than that of q, k and v, {t=} {tk=} {tv=} {t_slot=}" + assert ( + block_size == block_size_v + ), f"block size should be identical for key_cache, and value_cache {block_size} {block_size_v}" + assert ( + kh == vh == kh_cache == vh_cache + ), "KV head should be identical for k, v, key_cache, and value_cache" + assert ( + t_cache == t_cache_v + ), "Number of tokens should be identical for key_cache, and value_cache" + if flash_layout: + assert ( + d == dk == dv == dk_cache == dv_cache + ), "D dimension should be identical for q, k, and v" + else: + assert ( + d == dk == dv == dkx_cache * x_cache == dv_cache + ), "D dimension should be identical for q, k, and v" + assert x_cache == triton.next_power_of_2(x_cache), "x_size should be power of 2" + + assert d == triton.next_power_of_2(d), "D dimension should be power of 2" + assert block_size == triton.next_power_of_2( + block_size + ), "block_size should be power of 2" + assert qh % kh == 0, "Q heads must be multiple of H heads" + d_freq = cos.shape[-1] + assert (d_freq == d // 2) or ( + d_freq == d + ), "cos/sin last dim should be the same or half of the qk last dim" + reuse_freqs_front_part = 
d_freq == d // 2 + + if q_out is None: + q_out = torch.empty((t, qh, d), dtype=q.dtype, device=q.device) + + if k_out is None: + k_out = torch.empty((tk, kh, dk), dtype=k.dtype, device=q.device) + + if zeros_out is not None: + tz, qhz, dz = zeros_out.shape + assert ( + t == tz and qh == qhz and d == dz + ), f"q and zeros shape mismatch {q.shape=} {zeros_out.shape=}" + output_zeros = True + elif output_zeros: + zeros_out = torch.empty((t, qh, d), dtype=q.dtype, device=q.device) + else: + zeros_out = None + + n_pid = t * qh + (t_slot - t) * kh + grid = (n_pid, 1, 1) + _fused_qk_rope_reshape_and_cache_kernel[grid]( + q, + k, + v, + pos, + cos, + sin, + offs, + key_cache, + value_cache, + slot_mapping, + q_out, + k_out, + zeros_out, + t, + t_slot, + *q.stride(), + *k.stride(), + *v.stride(), + cos.stride(0), + cos.stride(-1), + *q_out.stride(), + *k_out.stride(), + key_cache.stride(0) if not flash_layout else key_cache.stride(0), + key_cache.stride(1) if not flash_layout else key_cache.stride(2), + key_cache.stride(2) if not flash_layout else key_cache.stride(3), + key_cache.stride(3) if not flash_layout else key_cache.stride(1), + key_cache.stride(4) if not flash_layout else 0, + value_cache.stride(0) if not flash_layout else value_cache.stride(0), + value_cache.stride(1) if not flash_layout else value_cache.stride(2), + value_cache.stride(2) if not flash_layout else value_cache.stride(3), + value_cache.stride(3) if not flash_layout else value_cache.stride(1), + zeros_out.stride(0) if zeros_out is not None else 0, + zeros_out.stride(1) if zeros_out is not None else 0, + zeros_out.stride(2) if zeros_out is not None else 0, + k_scale_ptr=k_scale, + v_scale_ptr=v_scale, + QH_PER_KH=qh // kh, + QH=qh, + KH=kh, + REUSE_FREQS_FRONT_PART=reuse_freqs_front_part, + IS_NEOX=is_neox, + BLOCK_D_pe=d, + BLOCK_D_HALF_pe=d // 2, + BLOCK_SIZE=block_size, + X_SIZE=x_cache if not flash_layout else 0, + FLASH_LAYOUT=flash_layout, + HAVE_POS=(offs is not None), + HAVE_K_SCALE=(k_scale is 
not None and apply_scale), + HAVE_V_SCALE=(v_scale is not None and apply_scale), + HAVE_ZEROS=output_zeros, + num_warps=1, + ) + + if zeros_out is not None: + return q_out, k_out, key_cache, value_cache, zeros_out + return q_out, k_out, key_cache, value_cache + + def fused_qk_rope_cosine_cache_llama( q: torch.Tensor, k: torch.Tensor, diff --git a/aiter/ops/triton/fused_mul_add.py b/aiter/ops/triton/fused_mul_add.py index ff7bb36a1f..c756ffc94c 100644 --- a/aiter/ops/triton/fused_mul_add.py +++ b/aiter/ops/triton/fused_mul_add.py @@ -30,8 +30,7 @@ def fused_mul_add( - out: same shape as x """ _LOGGER.info( - f"FUSED_MUL_ADD: x={tuple(x.shape)} a={tuple(a.shape) if isinstance(a, torch.Tensor) else a} " - + f"b={tuple(b.shape) if isinstance(b, torch.Tensor) else b}" + f"FUSED_MUL_ADD: x={tuple(x.shape)} a={tuple(a.shape) if isinstance(a, torch.Tensor) else a} b={tuple(b.shape) if isinstance(b, torch.Tensor) else b}" ) N = x.numel() diff --git a/aiter/ops/triton/fused_qkv_split_qk_rope.py b/aiter/ops/triton/fused_qkv_split_qk_rope.py new file mode 100644 index 0000000000..5a667bb656 --- /dev/null +++ b/aiter/ops/triton/fused_qkv_split_qk_rope.py @@ -0,0 +1,89 @@ +import torch +import triton +from aiter.ops.triton._triton_kernels.fused_qkv_split_qk_rope import ( + _fused_qkv_split_qk_rope_kernel, +) + + +def fused_qkv_split_qk_rope( + qkv: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + positions: torch.Tensor, + qh: int, + kvh: int, + head_dim: int, + is_neox: bool = True, + offsets: torch.Tensor = None, + reuse_freqs_front_part: bool = True, + nope_first: bool = False, +): + T = qkv.shape[0] + q_size = qh * head_dim + kv_size = kvh * head_dim + + assert qh >= kvh and qh % kvh == 0, "qh must be mutiple of kvh" + + q = torch.empty((qkv.shape[0], qh, head_dim), dtype=qkv.dtype, device=qkv.device) + k = torch.empty((qkv.shape[0], kvh, head_dim), dtype=qkv.dtype, device=qkv.device) + v = torch.empty((qkv.shape[0], kvh, head_dim), dtype=qkv.dtype, 
device=qkv.device) + + if cos.shape[-1] == head_dim // 2: + if reuse_freqs_front_part: + have_nope = False + else: + have_nope = True + elif cos.shape[-1] == head_dim // 4: + have_nope = True + else: + have_nope = False + + assert qkv.shape[-1] == q_size + 2 * kv_size, "Shape error" + assert head_dim // ((2 if have_nope else 1)) == triton.next_power_of_2( + head_dim // ((2 if have_nope else 1)) + ), "head_dim should be power of 2" + + if have_nope: + BLOCK_D = head_dim // 2 + BLOCK_D_HALF = head_dim // 4 + else: + BLOCK_D = head_dim + BLOCK_D_HALF = head_dim // 2 + + BLOCK_T = 32 + num_warps = 4 + waves_per_eu = 0 + grid = (triton.cdiv(T, BLOCK_T), qh, 1) + + _fused_qkv_split_qk_rope_kernel[grid]( + qkv, + cos, + sin, + positions, + offsets, + q, + k, + v, + T, + *qkv.stride(), + cos.stride(0), + cos.stride(-1), + *positions.stride(), + *q.stride(), + *k.stride(), + HAVE_NOPE=have_nope, + NOPE_FIRST=nope_first, + REUSE_FREQS_FRONT_PART=reuse_freqs_front_part, + IS_NEOX=is_neox, + HAVE_POS=(positions is not None), + HAVE_OFFS=(offsets is not None), + QH=qh, + KVH=kvh, + BLOCK_T=BLOCK_T, + BLOCK_D=BLOCK_D, + BLOCK_D_HALF=BLOCK_D_HALF, + num_warps=num_warps, + waves_per_eu=waves_per_eu, + ) + + return q, k, v diff --git a/aiter/ops/triton/gemm_a16w16.py b/aiter/ops/triton/gemm_a16w16.py index f162f80030..b3d4f00bd3 100644 --- a/aiter/ops/triton/gemm_a16w16.py +++ b/aiter/ops/triton/gemm_a16w16.py @@ -5,8 +5,6 @@ import torch import triton import triton.language as tl -import aiter.ops.triton.utils._triton.arch_info as arch_info -from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH from aiter.ops.triton._triton_kernels.gemm_a16w16 import ( _gemm_a16_w16_kernel, _gemm_a16w16_reduce_kernel, @@ -21,10 +19,12 @@ def gemm_a16w16( x, w, + bias: Optional[torch.Tensor] = None, dtype: Optional[float] = torch.bfloat16, y: Optional[torch.Tensor] = None, config: Optional[dict] = None, activation: Optional[str] = None, + skip_reduce: Optional[bool] = False, ): """ 
Computes the 16 bit matmul Y = X x W @@ -50,15 +50,17 @@ def gemm_a16w16( N, K = w.shape w = w.T - if y is None: - y = torch.empty((M, N), dtype=dtype, device=x.device) - if config is None: config = _get_config(M, N, K) + if y is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): + y = torch.empty((M, N), dtype=dtype, device=x.device) + if config["NUM_KSPLIT"] > 1: y_pp = torch.empty( - (config["NUM_KSPLIT"], M, N), dtype=torch.float32, device=y.device + (config["NUM_KSPLIT"], M, N), + dtype=torch.float32, + device=y.device if y is not None else x.device, ) else: y_pp = None @@ -73,6 +75,7 @@ def gemm_a16w16( _gemm_a16_w16_kernel[grid]( x, w, + bias, y if config["NUM_KSPLIT"] == 1 else y_pp, M, N, @@ -86,10 +89,15 @@ def gemm_a16w16( y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), activation=_get_activation_from_str(activation) if activation else "", use_activation=activation is not None, + ADD_BIAS=(bias is not None), + SKIP_REDUCE=skip_reduce, **config, ) if config["NUM_KSPLIT"] > 1: + if skip_reduce: + return y_pp + REDUCE_BLOCK_SIZE_M = 32 REDUCE_BLOCK_SIZE_N = 32 ACTUAL_KSPLIT = triton.cdiv(K, config["SPLITK_BLOCK_SIZE"]) @@ -99,6 +107,7 @@ def gemm_a16w16( triton.cdiv(N, REDUCE_BLOCK_SIZE_N), ) _gemm_a16w16_reduce_kernel[grid_reduce]( + bias, y_pp, y, M, @@ -114,6 +123,7 @@ def gemm_a16w16( triton.next_power_of_2(config["NUM_KSPLIT"]), activation=_get_activation_from_str(activation) if activation else "", use_activation=activation is not None, + ADD_BIAS=(bias is not None), ) return y diff --git a/aiter/ops/triton/gemm_afp4wfp4.py b/aiter/ops/triton/gemm_afp4wfp4.py index c879a91b9e..4011501965 100644 --- a/aiter/ops/triton/gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm_afp4wfp4.py @@ -10,9 +10,14 @@ from aiter.ops.triton._triton_kernels.gemm_afp4wfp4 import ( _gemm_afp4_wfp4_kernel, _gemm_afp4_wfp4_kernel_preshuffled_scales, + _gemm_afp4_wfp4_kernel_preshuffled_weight_scales, _gemm_afp4_wfp4_reduce_kernel, _get_config, ) +from .utils.core 
import AITER_TRITON_CONFIGS_PATH + +import os +from aiter.utility.triton.triton_metadata_redirect import AOTMetadataContext _LOGGER = AiterTritonLogger() @@ -202,7 +207,7 @@ def gemm_afp4wfp4_preshuffled_scales( Key parameters: - - X: Matrix X with shape (M, K). + - X: Matrix X with shape (M, K). M >= 32 is required - W: Matrix W with shape (N, K). - X_scales: Matrix with shape (M // 32, K) - W_scales: Matrix with shape (N // 32, K) @@ -219,6 +224,8 @@ def gemm_afp4wfp4_preshuffled_scales( # Transpose w w = w.T + assert M >= 32, f"M >= 32 is required, but got {M=}" + if y is None: y = torch.empty((M, N), dtype=dtype, device=x.device) @@ -308,7 +315,175 @@ def gemm_afp4wfp4_preshuffled_scales( REDUCE_BLOCK_SIZE_M, REDUCE_BLOCK_SIZE_N, ACTUAL_KSPLIT, - config["NUM_KSPLIT"], + triton.next_power_of_2(config["NUM_KSPLIT"]), + ) + + return y + + +def gemm_afp4wfp4_preshuffled_weight_scales( + x, + w, + x_scales, + w_scales, + dtype: Optional[float] = torch.bfloat16, + y: Optional[torch.Tensor] = None, + config: Optional[dict] = None, +): + """ + Computes the matmul Y = X x W + X and W are e2m1 fp4 tensors. + x_scales and w_scales are e8m0 tensors. + Every 32 elements in the K dimension share one e8m0 scale. + + + Key parameters: + - X: Matrix X with shape (M, K). + - W: Matrix W with shape (N, K). + - X_scales: Matrix with shape (M // 32, K) + - W_scales: Matrix with shape (N // 32, K) + + Returns: + - Y: The output matrix with shape (M, N). 
+ """ + + assert arch_info.is_fp4_avail(), "MXFP4 is not available on your device" + + M, K = x.shape + N, K = w.shape + N = N * 16 + K = K // 16 + + if y is None: + y = torch.empty((M, N), dtype=dtype, device=x.device) + + if config is None: + config = _get_config(M, N, K, True) + + if config["NUM_KSPLIT"] > 1: + SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk( + K, config["BLOCK_SIZE_K"], config["NUM_KSPLIT"] + ) + + config["SPLITK_BLOCK_SIZE"] = SPLITK_BLOCK_SIZE + config["BLOCK_SIZE_K"] = BLOCK_SIZE_K + config["NUM_KSPLIT"] = NUM_KSPLIT + + if _USE_GEMM_SPLITK_BF16: + y_pp = torch.empty( + (config["NUM_KSPLIT"], M, N), dtype=y.dtype, device=y.device + ) + else: + y_pp = torch.empty( + (config["NUM_KSPLIT"], M, N), dtype=torch.float32, device=y.device + ) + else: + config["SPLITK_BLOCK_SIZE"] = 2 * K + y_pp = None + + if config["BLOCK_SIZE_K"] >= 2 * K: + config["BLOCK_SIZE_K"] = triton.next_power_of_2(2 * K) + config["SPLITK_BLOCK_SIZE"] = 2 * K + + config["BLOCK_SIZE_N"] = max(config["BLOCK_SIZE_N"], 32) + if M < 32: + assert ( + config["BLOCK_SIZE_M"] <= 16 + ), "for M < 32, BLOCK_SIZE_M must be 16 or less as x_scale are assumed to be un-shuffled" + else: + assert ( + config["BLOCK_SIZE_M"] >= 32 + ), "for M >= 32, BLOCK_SIZE_M must be 32 or more as x_scale are assumed to be preshuffled" + + grid = lambda META: ( # noqa: E731 + ( + META["NUM_KSPLIT"] + * triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]) + ), + ) + + M_POW2 = triton.next_power_of_2(M) + if M < 32 and M_POW2 > 16: + M_POW2 = 16 + metadata_pth = f"{AITER_TRITON_CONFIGS_PATH}/gemm/aot/{_gemm_afp4_wfp4_kernel_preshuffled_weight_scales.fn.__name__}_M={M_POW2}-N={N}-K={K*2}" + if os.path.exists(metadata_pth): + with AOTMetadataContext( + _gemm_afp4_wfp4_kernel_preshuffled_weight_scales.fn.__name__, + f"{metadata_pth}", + ): + _gemm_afp4_wfp4_kernel_preshuffled_weight_scales[grid]( + x, + w, + y if config["NUM_KSPLIT"] == 1 else y_pp, + x_scales, + w_scales, + M, 
+ N, + K, + x.stride(0), + x.stride(1), + w.stride(0), + w.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), + y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), + y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), + x_scales.stride(0), + x_scales.stride(1), + w_scales.stride(0), + w_scales.stride(1), + **config, + ) + else: + _gemm_afp4_wfp4_kernel_preshuffled_weight_scales[grid]( + x, + w, + y if config["NUM_KSPLIT"] == 1 else y_pp, + x_scales, + w_scales, + M, + N, + K, + x.stride(0), + x.stride(1), + w.stride(0), + w.stride(1), + 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), + y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), + y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), + x_scales.stride(0), + x_scales.stride(1), + w_scales.stride(0), + w_scales.stride(1), + **config, + ) + + if config["NUM_KSPLIT"] > 1: + REDUCE_BLOCK_SIZE_M = 16 + # TODO: Need to debug - REDUCE_BLOCK_SIZE_N=128 with fp32 partials fails + # NOTE: REDUCE_BLOCK_SIZE_N=16 gives best perf with fp32 partials and + # REDUCE_BLOCK_SIZE_N=128 gives best perf with bf16 partials + REDUCE_BLOCK_SIZE_N = 128 if _USE_GEMM_SPLITK_BF16 else 64 + ACTUAL_KSPLIT = triton.cdiv(K, (config["SPLITK_BLOCK_SIZE"] // 2)) + + grid_reduce = ( + triton.cdiv(M, REDUCE_BLOCK_SIZE_M), + triton.cdiv(N, REDUCE_BLOCK_SIZE_N), + ) + _gemm_afp4_wfp4_reduce_kernel[grid_reduce]( + y_pp, + y, + M, + N, + y_pp.stride(0), + y_pp.stride(1), + y_pp.stride(2), + y.stride(0), + y.stride(1), + REDUCE_BLOCK_SIZE_M, + REDUCE_BLOCK_SIZE_N, + ACTUAL_KSPLIT, + triton.next_power_of_2(config["NUM_KSPLIT"]), ) return y diff --git a/aiter/ops/triton/split_qkv.py b/aiter/ops/triton/split_qkv.py new file mode 100644 index 0000000000..70326168bb --- /dev/null +++ b/aiter/ops/triton/split_qkv.py @@ -0,0 +1,31 @@ +import torch +from aiter.ops.triton._triton_kernels.split_qkv import _split_qkv_kernel + + +def split_qkv( + qkv, + q_size, + kv_size, +): + + q = 
torch.empty(qkv.shape[0], q_size, dtype=qkv.dtype, device=qkv.device) + k = torch.empty(qkv.shape[0], kv_size, dtype=qkv.dtype, device=qkv.device) + v = torch.empty(qkv.shape[0], kv_size, dtype=qkv.dtype, device=qkv.device) + + grid = qkv.shape[0] + + # TODO: Add support for dim + _split_qkv_kernel[(grid,)]( + qkv, + q, + k, + v, + qkv.stride(0), + q.stride(0), + k.stride(0), + v.stride(0), + q_size, + kv_size, + ) + + return q, k, v diff --git a/aiter/ops/triton/unified_attention.py b/aiter/ops/triton/unified_attention.py new file mode 100644 index 0000000000..b2231ee563 --- /dev/null +++ b/aiter/ops/triton/unified_attention.py @@ -0,0 +1,332 @@ +# The kernels in this file are adapted from vLLM: +# https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_unified_attention.py +import triton +import torch +from aiter.ops.triton.utils.device_info import get_num_sms +import math +from aiter.ops.triton._triton_kernels.unified_attention import ( + kernel_unified_attention_2d, + kernel_unified_attention_3d, + reduce_segments, +) + + +def select_2d_config( + block_size, + head_size, + sliding_window, + all_decode, + max_seqlen_q, + max_seqlen_k, + num_queries_per_kv, + num_2d_prgms, +): + BLOCK_M = ( + 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv) + ) + TILE_SIZE = 64 + # in case head_size is large + max_num_stages_2d = 4 + if head_size > 128: + max_num_stages_2d = 2 + if all_decode == False: + num_stages_2d = 1 + num_warps = 2 + else: + num_stages_2d = 3 + num_warps = 2 + TILE_SIZE = block_size + + if max_seqlen_q >= 256: + BLOCK_M = 128 + num_stages_2d = 1 + num_warps = 4 + BLOCK_Q = BLOCK_M // num_queries_per_kv + num_stages_2d = min(max_num_stages_2d, num_stages_2d) + return { + "BLOCK_M": BLOCK_M, + "BLOCK_Q": BLOCK_Q, + "TILE_SIZE": TILE_SIZE, + "num_warps": num_warps, + "num_stages": num_stages_2d, + "waves_per_eu": 2, + } + + +def select_3d_config( + head_size, block_size, element_size, max_seqlen_k, 
target_num_prgms, num_2d_prgms +): + reduce_num_warps = 2 + attn_warps = 2 + TILE_SIZE = block_size + MAX_SEGMENTS = min(128, math.ceil(max_seqlen_k / TILE_SIZE)) + num_segments = math.ceil(target_num_prgms / num_2d_prgms) + num_segments = triton.next_power_of_2(num_segments) + num_segments = min(num_segments, 128) + MIN_SEGMENTS = 16 if TILE_SIZE <= 16 else 8 + num_segments = max(num_segments, MIN_SEGMENTS) + if num_segments == MIN_SEGMENTS: + reduce_num_warps = 1 + attn_config = { + "TILE_SIZE": TILE_SIZE, + "NUM_SEGMENTS_PER_SEQ": num_segments, + "num_warps": attn_warps, + "num_stages": 1, + "waves_per_eu": 2, + } + reduce_config = { + "TILE_SIZE": TILE_SIZE, + "NUM_SEGMENTS_PER_SEQ": num_segments, + "num_warps": reduce_num_warps, + "num_stages": 1, + "waves_per_eu": 2, + } + return attn_config, reduce_config + + +def use_2d_kernel( + head_size, + sliding_window, + all_decode, + max_seqlen_q, + max_seqlen_k, + target_num_prgms, + num_2d_prgms, +): + return ( + (sliding_window > 0) + or (max_seqlen_k <= 512) + or (num_2d_prgms > target_num_prgms) + ) + + +def unified_attention( + q, + k, + v, + out, + cu_seqlens_q, + max_seqlen_q, + seqused_k, + max_seqlen_k, + softmax_scale, + causal, + window_size, + block_table, + softcap, + q_descale, + k_descale, + v_descale, + alibi_slopes=None, + output_scale=None, + qq_bias=None, + # Optional tensor for sinks + sinks=None, +): + assert causal, "Only causal attention is supported" + assert q_descale is None, "Q scales not supported" + + if sinks is not None: + assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size" + + use_alibi_slopes = alibi_slopes is not None + use_qq_bias = qq_bias is not None + SLIDING_WINDOW = 1 + window_size[0] + + block_size = v.shape[1] + num_seqs = len(seqused_k) + num_query_heads = q.shape[1] + num_kv_heads = k.shape[2] + num_queries_per_kv = num_query_heads // num_kv_heads + head_size = q.shape[2] + + BLOCK_M = ( + 16 if num_queries_per_kv <= 16 else 
triton.next_power_of_2(num_queries_per_kv) + ) + BLOCK_Q = BLOCK_M // num_queries_per_kv + assert BLOCK_Q >= 1 + # Ideally we would launch with kernel with: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. + # However, it is slow to realize the query_lens on cpu. + # Instead we use upper-bound: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] + # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] + # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs + # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs + # = floor(q.shape[0] / BLOCK_Q) + num_seqs + cu_count = get_num_sms() + total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + target_num_prgms = cu_count * 4 + num_2d_prgms = total_num_q_blocks * num_kv_heads + ALL_DECODE = max_seqlen_q == 1 + # if batch contains a prefill + if use_2d_kernel( + head_size, + SLIDING_WINDOW, + ALL_DECODE, + max_seqlen_q, + max_seqlen_k, + target_num_prgms, + num_2d_prgms, + ): + config = select_2d_config( + block_size, + head_size, + SLIDING_WINDOW, + ALL_DECODE, + max_seqlen_q, + max_seqlen_k, + num_queries_per_kv, + num_2d_prgms, + ) + assert config["BLOCK_Q"] >= 1 + total_num_q_blocks = q.shape[0] // config["BLOCK_Q"] + num_seqs + + kernel_unified_attention_2d[ + ( + num_kv_heads, + total_num_q_blocks, + ) + ]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + sink_ptr=sinks, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + out_scale=1 / output_scale if output_scale is not None else 1.0, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + 
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), + SLIDING_WINDOW=SLIDING_WINDOW, + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + num_seqs=num_seqs, + USE_FP8=output_scale is not None, + ALL_DECODE=ALL_DECODE, + **config, + ) + + else: + attn_config, reduce_config = select_3d_config( + head_size, + block_size, + q.element_size(), + max_seqlen_k, + target_num_prgms, + num_2d_prgms, + ) + NUM_SEGMENTS = attn_config["NUM_SEGMENTS_PER_SEQ"] + segm_output = torch.empty( + q.shape[0], + num_query_heads, + NUM_SEGMENTS, + triton.next_power_of_2(head_size), + dtype=torch.float32, + device=q.device, + ) + segm_max = torch.empty( + q.shape[0], + num_query_heads, + NUM_SEGMENTS, + dtype=torch.float32, + device=q.device, + ) + segm_expsum = torch.empty( + q.shape[0], + num_query_heads, + NUM_SEGMENTS, + dtype=torch.float32, + device=q.device, + ) + + kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + sink_ptr=sinks, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + 
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), + SLIDING_WINDOW=SLIDING_WINDOW, + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + ALL_DECODE=ALL_DECODE, + **attn_config, + ) + reduce_segments[(q.shape[0], num_query_heads)]( + output_ptr=out, + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + seq_lens_ptr=seqused_k, + num_seqs=num_seqs, + num_query_heads=num_query_heads, + out_scale_inv=1 / output_scale if output_scale is not None else 1.0, + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + block_table_stride=block_table.stride(0), + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + USE_FP8=output_scale is not None, + **reduce_config, + ) diff --git a/op_tests/op_benchmarks/triton/bench_gemm_a16w16.py b/op_tests/op_benchmarks/triton/bench_gemm_a16w16.py index 0eef93098c..b401be6a4e 100644 --- a/op_tests/op_benchmarks/triton/bench_gemm_a16w16.py +++ b/op_tests/op_benchmarks/triton/bench_gemm_a16w16.py @@ -32,8 +32,8 @@ def bench_gemm_fn( ): # NOTE: Assume bias and output has the same dtype c_dtype = torch.bfloat16 - x, w, out_dtype, y = generate_gemm_a16w16_inputs( - M, N, K, c_dtype, layout=layout, output=True + x, w, bias, out_dtype, y = generate_gemm_a16w16_inputs( + M, N, K, c_dtype, layout=layout, output=True, bias=True ) # flops flops = 2.0 * M * N * K @@ -57,7 +57,7 @@ def bench_gemm_fn( ) else: ms = triton.testing.do_bench( - lambda: gemm_a16w16(x, w, c_dtype, y, activation=activation), + lambda: 
gemm_a16w16(x, w, bias, c_dtype, y, activation=activation), warmup=25, rep=100, # noqa: E731 ) diff --git a/op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py b/op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py index 4d988e89e4..45f1a7fce0 100644 --- a/op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py +++ b/op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py @@ -6,6 +6,7 @@ from aiter.ops.triton.gemm_afp4wfp4 import ( gemm_afp4wfp4, gemm_afp4wfp4_preshuffled_scales, + gemm_afp4wfp4_preshuffled_weight_scales, ) from op_tests.triton_tests.test_gemm_afp4wfp4 import generate_gemm_afp4wfp4_inputs from op_tests.op_benchmarks.triton.utils.argparse import ( @@ -17,24 +18,21 @@ get_model_benchmark_object, get_shape_benchmark_object, print_vgpr, - get_caller_name_no_ext, ) import aiter.ops.triton.utils._triton.arch_info as arch_info -TRITON_HIP_PRESHUFFLE_SCALES = ( - os.environ.get("TRITON_HIP_PRESHUFFLE_SCALES", "0") == "1" -) - -def bench_gemm_fn(M: int, N: int, K: int, metric: str, layout: str): +def bench_gemm_fn(M: int, N: int, K: int, metric: str, layout: str, shuffle: bool): c_dtype = torch.bfloat16 - x, w, _, _, x_scale, w_scale, _, y = generate_gemm_afp4wfp4_inputs( + x, _, w, _, _, x_scale, w_scale, _, y = generate_gemm_afp4wfp4_inputs( M, N, K, c_dtype, layout=layout, output=True, + shuffle_scales_fg=shuffle, + shuffle_weight_fg=shuffle, ) # flops flops = 2.0 * M * N * K @@ -46,11 +44,10 @@ def bench_gemm_fn(M: int, N: int, K: int, metric: str, layout: str): ) mem_write = (M * N) * 2 # TODO: Fix for c_dtype != bf16 mem = mem_read + mem_write - - if TRITON_HIP_PRESHUFFLE_SCALES: + if shuffle: ms = triton.testing.do_bench( - lambda: gemm_afp4wfp4_preshuffled_scales( - x, w, x_scale, w_scale, c_dtype, y + lambda: gemm_afp4wfp4_preshuffled_weight_scales( + x, w, x_scale, w_scale, c_dtype, y # , config=config ), warmup=25, rep=100, @@ -101,7 +98,7 @@ def run_benchmark(args, defaults): def run_model_benchmark(args): - benchmark = 
get_model_benchmark_object(get_caller_name_no_ext(), args) + benchmark = get_model_benchmark_object("GEMM MXFP4 x MXFP4 Benchmark", args) @triton.testing.perf_report([benchmark]) def bench_gemm_afp4wfp4( @@ -119,17 +116,17 @@ def bench_gemm_afp4wfp4( # Divide K by tensor parallel K = math.ceil(K / args.tp) - return bench_gemm_fn(M, N, K, metric, args.layout) + return bench_gemm_fn(M, N, K, metric, args.layout, args.shuffle) bench_gemm_afp4wfp4.run(save_path="." if args.o else None, print_data=True) def run_shape_benchmark(args): - benchmark = get_shape_benchmark_object(get_caller_name_no_ext(), args) + benchmark = get_shape_benchmark_object("GEMM MXFP4 x MXFP4 Benchmark", args) @triton.testing.perf_report([benchmark]) def bench_gemm_afp4wfp4(M, N, K, metric, model_name=None, **kwargs): - return bench_gemm_fn(M, N, K, metric, args.layout) + return bench_gemm_fn(M, N, K, metric, args.layout, args.shuffle) bench_gemm_afp4wfp4.run(save_path="." if args.o else None, print_data=True) @@ -137,6 +134,9 @@ def bench_gemm_afp4wfp4(M, N, K, metric, model_name=None, **kwargs): def parse_args(): parser = get_parser("MXFP4 x MXFP4 GEMM") parser = add_argparse_ff(parser) + parser.add_argument( + "--shuffle", action="store_true", help="Preshuffle weight and scales" + ) return get_ff_args(parser) @@ -149,7 +149,7 @@ def main(): if args.print_vgpr: print("Retrieving VGPR usage for Triton kernels...") fun = lambda: run_benchmark(args, defaults) # noqa: E731 - print_vgpr(fun, get_caller_name_no_ext()) + print_vgpr(fun, "GEMM") return 0 run_benchmark(args, defaults) diff --git a/op_tests/triton_tests/test_fused_add_rmsnorm_pad.py b/op_tests/triton_tests/test_fused_add_rmsnorm_pad.py new file mode 100644 index 0000000000..ebf29d8127 --- /dev/null +++ b/op_tests/triton_tests/test_fused_add_rmsnorm_pad.py @@ -0,0 +1,53 @@ +import torch +import pytest +from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad +import torch.nn.functional as F + + +def generate_inputs(M, N, 
has_res, dtype): + x = torch.randn((M, N), dtype=dtype, device="cuda") + weight = torch.randn((N,), dtype=dtype, device="cuda") + res = torch.randn((M, N), dtype=dtype, device="cuda") if has_res else None + return x, weight, res + + +def run_torch(x, weight, eps=1e-6, res=None, pad_to_multiple=0): + dtype = x.dtype + x = x.to(torch.float32) + if res is not None: + x = x + res.to(torch.float32) + res = x.to(dtype) + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * weight * torch.rsqrt(variance + eps) + N = x.shape[-1] + if pad_to_multiple > 0: + pad = (N + pad_to_multiple - 1) // pad_to_multiple * pad_to_multiple - N + if pad > 0: + x = F.pad(x, (0, pad, 0, 0), "constant", 0.0) + x = x.to(dtype) + if res is not None: + return x, res + return x + + +@pytest.mark.parametrize("M", [1, 4, 8, 16, 32, 256, 8192]) +@pytest.mark.parametrize("N", [4, 16, 320, 640, 2880]) +@pytest.mark.parametrize("has_res", [False, True]) +@pytest.mark.parametrize("pad_to_multiple", [0, 256]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_mul_add(M: int, N: int, has_res: bool, pad_to_multiple: int, dtype): + + x, weight, res = generate_inputs(M, N, has_res, dtype) + + if has_res: + x_torch, res_torch = run_torch(x, weight, 1e-6, res, pad_to_multiple) + x_triton, res_triton = fused_add_rmsnorm_pad( + x, weight, 1e-6, res, pad_to_multiple + ) + else: + x_torch = run_torch(x, weight, 1e-6, res, pad_to_multiple) + x_triton = fused_add_rmsnorm_pad(x, weight, 1e-6, res, pad_to_multiple) + + torch.testing.assert_close(x_torch.to(dtype), x_triton) + if has_res: + torch.testing.assert_close(res_torch.to(dtype), res_triton) diff --git a/op_tests/triton_tests/test_fused_fp8_quant.py b/op_tests/triton_tests/test_fused_fp8_quant.py new file mode 100644 index 0000000000..9621756f4f --- /dev/null +++ b/op_tests/triton_tests/test_fused_fp8_quant.py @@ -0,0 +1,219 @@ +import torch +import pytest +from aiter.ops.triton.fused_fp8_quant import ( + 
fused_rms_fp8_group_quant, + fused_flatten_fp8_group_quant, + fused_reduce_act_mul_fp8_group_quant, +) +from op_tests.triton_tests.test_quant_mxfp4 import torch_dynamic_mxfp4_quant +import aiter +import torch.nn.functional as F + +torch.manual_seed(0) + + +def rmsnorm(input, weight, eps=1e-6): + row_norm = input * input + row_norm = torch.sum(row_norm, dim=-1) + norm_factor = torch.rsqrt((row_norm / input.shape[1]) + eps) + rms_norm = input * norm_factor[:, None] * weight[None, :] + return rms_norm + + +def per_token_fp8_group_quant(x, dtype_quant, group_size=128): + DTYPE_MAX = torch.finfo(dtype_quant).max + M, N = x.shape + x_reshape = x.reshape(M, N // group_size, group_size).to(torch.float32) + x_max = torch.max(torch.abs(x_reshape), dim=-1, keepdim=True)[0] + x_max = torch.where(x_max < 1e-10, 1e-10, x_max).to(torch.float32) + x_scale = x_max / DTYPE_MAX + scale_recip = 1.0 / x_scale + x_quant = torch.clamp(x_reshape * scale_recip, -DTYPE_MAX, DTYPE_MAX).to( + dtype_quant + ) + x_quant = x_quant.reshape(M, N) + x_scale = x_scale.squeeze(-1) + + return x_quant, x_scale + + +def upcast(x, s, dtype, group_size=128): + x_N = x.shape[1] + x = x.reshape(-1, x_N // group_size, group_size).to(torch.float32) * s.reshape( + -1, s.shape[1], 1 + ) + x = x.reshape(-1, x_N) + return x.to(dtype=dtype) + + +def run_torch_rms_fp8_group_quant( + x1, w1, eps1, x2, w2, eps2, res1, dtype_quant, group_size +): + s = x1 + res1 + y1 = rmsnorm(s, w1, eps1) + y2 = rmsnorm(x2, w2, eps2) + y1_q, y1_s = per_token_fp8_group_quant(y1, dtype_quant, group_size) + return (y1_q, y1_s), y1.to(x1.dtype), y2.to(x1.dtype), s.to(x1.dtype) + + +def generate_fused_rms_quant_data(M, N1, N2, dtype=torch.bfloat16): + x1 = torch.randn((M, N1), dtype=dtype, device="cuda") / 10 + x2 = torch.randn((M, N2), dtype=dtype, device="cuda") / 10 + w1 = torch.ones((N1,), dtype=torch.float32, device="cuda") + w2 = torch.ones((N2,), dtype=torch.float32, device="cuda") + res1 = torch.randn((M, N1), dtype=dtype, 
device="cuda") / 10 + return x1, w1, x2, w2, res1 + + +@pytest.mark.parametrize("M", [1, 32, 256]) +@pytest.mark.parametrize("N1, N2", [(128, 128), (128, 7168), (7168, 7168)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_rms_fp8_group_quant(M: int, N1: int, N2: int, dtype): + group_size = 128 + dtype_quant = aiter.dtypes.fp8 + x1, w1, x2, w2, res1 = generate_fused_rms_quant_data(M, N1, N2, dtype) + + (y1_q_torch, y1_s_torch), y1_torch, y2_torch, y1_res_torch = ( + run_torch_rms_fp8_group_quant( + x1, w1, 1e-6, x2, w2, 1e-6, res1, dtype_quant, group_size + ) + ) + + (y1_q_triton, y1_s_triton), y1_triton, y2_triton, y1_res_triton = ( + fused_rms_fp8_group_quant( + x1, + w1, + 1e-6, + inp2=x2, + inp2_weight=w2, + inp2_epsilon=1e-6, + group_size=group_size, + dtype_quant=dtype_quant, + res1=res1, + output_unquantized_inp1=True, + ) + ) + + torch.testing.assert_close(y1_torch, y1_triton, atol=0.1, rtol=0.1) + torch.testing.assert_close(y2_torch, y2_triton, atol=0.1, rtol=0.1) + torch.testing.assert_close(y1_res_torch, y1_res_triton, atol=0.1, rtol=0.1) + + y1_upcast_torch = upcast( + y1_q_torch, y1_s_torch, dtype=torch.float32, group_size=group_size + ) + y1_upcast_triton = upcast( + y1_q_triton, y1_s_triton, dtype=torch.float32, group_size=group_size + ) + torch.testing.assert_close(y1_upcast_torch, y1_upcast_triton, atol=0.1, rtol=0.1) + + +def run_torch_flatten_fp8_group_quant(x, dtype_quant, group_size): + y_q, y_s = per_token_fp8_group_quant( + x.reshape(x.shape[0], -1), dtype_quant, group_size + ) + return y_q, y_s + + +@pytest.mark.parametrize("M", [1, 32, 256]) +@pytest.mark.parametrize("N1, N2", [(16, 128)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_flatten_fp8_group_quant(M: int, N1: int, N2: int, dtype): + group_size = 128 + dtype_quant = aiter.dtypes.fp8 + x = torch.randn((N1, M, N2), dtype=dtype, device="cuda") / 10 + x = x.transpose(0, 1) + + y_q_torch, y_s_torch = 
run_torch_flatten_fp8_group_quant(x, dtype_quant, group_size) + + y_q_triton, y_s_triton = fused_flatten_fp8_group_quant( + x, + group_size=group_size, + dtype_quant=dtype_quant, + ) + + y_upcast_torch = upcast( + y_q_torch, y_s_torch, dtype=torch.float32, group_size=group_size + ) + y_upcast_triton = upcast( + y_q_triton, y_s_triton, dtype=torch.float32, group_size=group_size + ) + torch.testing.assert_close(y_upcast_torch, y_upcast_triton, atol=0.1, rtol=0.1) + + +def run_torch_reduce_act_mul_fp8_group_quant( + x, x2, activation, dtype, dtype_quant, group_size=128 +): + x = x.clone() + y2 = None + if x.dim() == 3: + x = x.sum(axis=0) + y2 = x2.sum(axis=0).to(dtype=dtype) + else: + assert x2 is None, "x2 must be None in x.dim() == 2 cases" + n = x.shape[1] // 2 + x, x_mul = x.split([n, n], dim=-1) + if activation == "silu": + x = F.silu(x) * x_mul + elif activation == "gelu": + x = F.gelu(x) * x_mul + + y_q, y_s = per_token_fp8_group_quant(x, dtype_quant, group_size) + + return (y_q, y_s), y2 + + +def generate_fused_reduce_act_mul_fp8_group_quant( + M: int, + N1: int, + dtype=torch.bfloat16, + SPK: int = 1, + N2: int = 1, +): + if SPK == 1: + x = torch.randn((M, N1 * 2), dtype=dtype).cuda() / 10 + else: + x = torch.randn((SPK, M, N1 * 2), dtype=torch.float32).cuda() / 10 + x2 = None + if SPK > 1: + x2 = torch.randn((SPK, M, N2), dtype=torch.float32).cuda() / 10 + + return x, x2 + + +@pytest.mark.parametrize("M", [1, 32, 256, 131072]) +@pytest.mark.parametrize("N1, N2", [(256, 256)]) +@pytest.mark.parametrize("SPK", [1, 4, 14]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("activation", ["silu", "gelu"]) +def test_fused_reduce_act_mul_fp8_group_quant( + M: int, N1: int, N2: int, SPK: int, dtype, activation +): + group_size = 128 + dtype_quant = aiter.dtypes.fp8 + + x, x2 = generate_fused_reduce_act_mul_fp8_group_quant( + M, N1, dtype=dtype, SPK=SPK, N2=N2 + ) + + (y_q_torch, y_s_torch), y2_torch = 
run_torch_reduce_act_mul_fp8_group_quant( + x, x2, activation, dtype, dtype_quant, group_size + ) + + (y_q_triton, y_s_triton), y2_triton = fused_reduce_act_mul_fp8_group_quant( + x, + activation=activation, + x2=x2, + group_size=group_size, + dtype_quant=dtype_quant, + dtype=dtype, + ) + + torch.testing.assert_close(y2_torch, y2_triton, atol=0.1, rtol=0.1) + + y_upcast_torch = upcast( + y_q_torch, y_s_torch, dtype=torch.float32, group_size=group_size + ) + y_upcast_triton = upcast( + y_q_triton, y_s_triton, dtype=torch.float32, group_size=group_size + ) + torch.testing.assert_close(y_upcast_torch, y_upcast_triton, atol=0.1, rtol=0.1) diff --git a/op_tests/triton_tests/test_fused_gemm_a8w8_blockscale_a16w16.py b/op_tests/triton_tests/test_fused_gemm_a8w8_blockscale_a16w16.py new file mode 100644 index 0000000000..13fdf74cc8 --- /dev/null +++ b/op_tests/triton_tests/test_fused_gemm_a8w8_blockscale_a16w16.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +import torch +import triton +import pytest +from aiter.ops.triton.fused_gemm_a8w8_blockscale_a16w16 import ( + fused_gemm_a8w8_blockscale_a16w16, +) +from op_tests.triton_tests.test_gemm_a8w8_blockscale import ( + generate_gemm_a8w8_blockscale_inputs, +) +from op_tests.triton_tests.test_gemm_a8w8_blockscale import run_torch as run_torch_fp8 +from op_tests.triton_tests.test_gemm_a16w16 import generate_gemm_a16w16_inputs +import torch.nn.functional as F + + +block_shape = (128, 128) + + +def run_torch( + x_fp8, + w_fp8, + x_fp8_scale, + w_fp8_scale, + x_bf16, + w_bf16, + bias_fp8, + bias_bf16, + dtype=torch.bfloat16, +): + y_fp8 = run_torch_fp8(x_fp8, w_fp8, x_fp8_scale, w_fp8_scale, dtype) + if bias_fp8 is not None: + y_fp8 += bias_fp8 + y_bf16 = F.linear(x_bf16, w_bf16, bias=bias_bf16) + return y_fp8, y_bf16 + + +def run_triton( + x_fp8, + w_fp8, + x_fp8_scale, + w_fp8_scale, + x_bf16, + w_bf16, + bias_fp8, + bias_bf16, + dtype=torch.bfloat16, + y_fp8=None, + y_bf16=None, + skip_reduce=False, +): + return fused_gemm_a8w8_blockscale_a16w16( + x_fp8, + w_fp8, + x_fp8_scale, + w_fp8_scale, + x_bf16, + w_bf16, + bias_fp8=bias_fp8, + bias_bf16=bias_bf16, + dtype=dtype, + y_fp8=y_fp8, + y_bf16=y_bf16, + skip_reduce=skip_reduce, + ) + + +def get_x_vals(): + + x_vals = [(1, 1, 1, 128)] # minimal case + x_vals += [ + (m, n1, n2, k) + for k in [1024, 8192, 7168] + for n2 in [256, 512] + for n1 in [256, 512] + for m in [1, 8, 32, 64, 128, 8192] + ] + return x_vals + + +@pytest.mark.parametrize("M, N1, N2, K", get_x_vals()) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("output", [True, False]) +@pytest.mark.parametrize("skip_reduce", [True, False]) +def test_gemm(dtype, M, N1, N2, K, output, skip_reduce): + block_shape_n, block_shape_k = block_shape + + x_fp8, w_fp8, x_fp8_scale, w_fp8_scale, y_fp8 = ( + generate_gemm_a8w8_blockscale_inputs( + M, + N1, + K, + block_shape_n, + block_shape_k, + dtype=dtype, + output=output, + ) 
+ ) + x_bf16, w_bf16, bias_bf16, _, y_bf16 = generate_gemm_a16w16_inputs( + M, N2, K, dtype, output=output, bias=True + ) + bias_bf16 = torch.randn((N2,), dtype=bias_bf16.dtype, device=bias_bf16.device) + bias_fp8 = torch.randn((N1,), dtype=bias_bf16.dtype, device=bias_bf16.device) + y_torch_fp8, y_torch_bf16 = run_torch( + x_fp8, + w_fp8, + x_fp8_scale, + w_fp8_scale, + x_bf16, + w_bf16, + bias_fp8, + bias_bf16, + dtype, + ) + y_triton_fp8, y_triton_bf16 = run_triton( + x_fp8, + w_fp8, + x_fp8_scale, + w_fp8_scale, + x_bf16, + w_bf16, + bias_fp8, + bias_bf16, + dtype, + y_fp8, + y_bf16, + skip_reduce=skip_reduce, + ) + + if y_triton_fp8.dim() == 3: + y_triton_fp8 = y_triton_fp8.sum(axis=0).to(dtype=dtype) + y_triton_bf16 = y_triton_bf16.sum(axis=0).to(dtype=dtype) + + triton.testing.assert_close(y_torch_fp8, y_triton_fp8, atol=0.1, rtol=1e-1) + triton.testing.assert_close(y_torch_bf16, y_triton_bf16, atol=0.1, rtol=1e-1) diff --git a/op_tests/triton_tests/test_fused_kv_cache.py b/op_tests/triton_tests/test_fused_kv_cache.py index e48d2556d1..51e899cc31 100644 --- a/op_tests/triton_tests/test_fused_kv_cache.py +++ b/op_tests/triton_tests/test_fused_kv_cache.py @@ -6,10 +6,420 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from test_rope import ref_rope_sbhd_fwd, RotateStyle from .test_rope import generate_rope_inputs -from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cosine_cache_llama +from aiter.ops.triton.fused_kv_cache import ( + fused_qk_rope_cat_and_cache_mla, + fused_qk_rope_reshape_and_cache, + fused_qk_rope_cosine_cache_llama, +) from aiter.ops.triton.utils._triton import arch_info +@pytest.mark.parametrize("T", [1, 2, 4, 2048]) +@pytest.mark.parametrize("QH_per_KH", [1, 16]) +@pytest.mark.parametrize("KH", [1, 8]) +@pytest.mark.parametrize("D", [128]) # For now, D is power of 2. 
D >= 16 +@pytest.mark.parametrize("D_q_nope", [128]) +@pytest.mark.parametrize("D_lora", [512]) +@pytest.mark.parametrize("num_kv_cahce_tokens", [16384]) +@pytest.mark.parametrize("rotate_style", [RotateStyle.GPTJ, RotateStyle.NEOX]) +@pytest.mark.parametrize("reuse_freqs_front_part", [False, True]) +@pytest.mark.parametrize("cache_dtype", [torch.bfloat16, torch.uint8]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_fused_qk_rope_cat_and_cache_mla( + T: int, + QH_per_KH: int, + KH: int, + D: int, + D_q_nope: int, + D_lora: int, + num_kv_cahce_tokens: int, + rotate_style: int, + reuse_freqs_front_part: bool, + cache_dtype: bool, + dtype: torch.dtype, +): + pos = True + _, _, _, _, freqs, positions, offsets, cos, sin = generate_rope_inputs( + 1, + T, + KH, + QH_per_KH, + D, + cached=True, + reuse_freqs_front_part=reuse_freqs_front_part, + nope=False, + pos=pos, + offs=False, + two_inputs=True, + layout="thd", + dtype=dtype, + ) + q = torch.randn((T, QH_per_KH * KH, D_q_nope + D), dtype=dtype, device="cuda") + q_nope, q_pe = q.split((D_q_nope, D), dim=-1) + k_lora = torch.randn((T, KH, D_lora), dtype=dtype, device=q.device) / ( + 20 if cache_dtype == torch.uint8 else 1 + ) + k_pe = torch.randn((T, KH, D), dtype=dtype, device=q.device) / ( + 20 if cache_dtype == torch.uint8 else 1 + ) + + if cache_dtype == torch.uint8: + if arch_info.get_arch() in ["gfx950"]: + cache_dtype_actual = torch.float8_e4m3fn + else: + cache_dtype_actual = torch.float8_e4m3fnuz + + kv_cache = torch.zeros( + (num_kv_cahce_tokens, KH, D_lora + D), dtype=cache_dtype, device="cuda" + ) + + if cache_dtype == torch.uint8: + k_scale = torch.randn( + [ + 1, + ], + dtype=torch.float32, + device="cuda", + )[0] + else: + k_scale = torch.ones( + [ + 1, + ], + dtype=torch.float32, + device="cuda", + )[0] + slot_mapping = torch.randperm(T, device="cuda") + kv_cache_og_dtype = kv_cache.dtype + + ref_freqs = ( + freqs[positions if offsets is None else torch.add(positions, offsets)].squeeze( + 
-2 + ) + if pos + else freqs + ) + + torch_q_nope = q_nope + torch_q_pe = q_pe + torch_k_lora = k_lora + torch_k_pe = k_pe + + torch_q_pe = ref_rope_sbhd_fwd( + torch_q_pe.unsqueeze(0), + ref_freqs, + rotate_style=rotate_style, + reuse_freqs_front_part=reuse_freqs_front_part, + nope_first=False, + ).squeeze(0) + torch_k_pe = ref_rope_sbhd_fwd( + torch_k_pe.unsqueeze(0), + ref_freqs, + rotate_style=rotate_style, + reuse_freqs_front_part=reuse_freqs_front_part, + nope_first=False, + ).squeeze(0) + + torch_kv_cache = kv_cache.clone() + torch_k_pe_og_dtype = torch_k_pe.clone() + torch_q = torch.cat((torch_q_nope, torch_q_pe), dim=-1) + torch_decode_q_pe = torch_q_pe + if cache_dtype == torch.uint8: + torch_kv_cache = torch_kv_cache.view(cache_dtype_actual) + torch_k_lora = (torch_k_lora.to(torch.float32) / k_scale).to(cache_dtype_actual) + torch_k_pe = (torch_k_pe.to(torch.float32) / k_scale).to(cache_dtype_actual) + else: + torch_k_lora = torch_k_lora + torch_k_pe = torch_k_pe + + torch_zeros = torch.zeros(((T, QH_per_KH * KH, D_lora)), dtype=dtype, device="cuda") + torch_kv_cache[slot_mapping, :, :] = torch.cat((torch_k_lora, torch_k_pe), dim=-1) + torch_kv_cache = torch_kv_cache.view(kv_cache_og_dtype) + + triton_kv_cache = kv_cache.clone() + if cache_dtype == torch.uint8: + triton_kv_cache = triton_kv_cache.view(cache_dtype_actual) + triton_q, triton_decode_q_pe, triton_k_pe, triton_kv_cache, triton_zeros = ( + fused_qk_rope_cat_and_cache_mla( + q_nope, + q_pe, + k_lora, + k_pe, + triton_kv_cache, + slot_mapping, + positions, + cos, + sin, + k_scale, + (rotate_style == RotateStyle.NEOX), + num_decode_toks_for_zeros=T, + apply_scale=(k_pe.dtype != kv_cache.dtype), + q_out=None, + decode_q_pe_out=None, + k_pe_out=None, + ) + ) + triton_kv_cache = triton_kv_cache.view(kv_cache_og_dtype) + + torch.testing.assert_close(torch_q, triton_q, atol=1e-1, rtol=1e-1) + torch.testing.assert_close( + torch_decode_q_pe, triton_decode_q_pe, atol=1e-1, rtol=1e-1 + ) + 
torch.testing.assert_close(torch_k_pe_og_dtype, triton_k_pe, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(torch_zeros, triton_zeros, atol=0.1, rtol=0.1) + + if cache_dtype == torch.uint8: + torch_kv_cache = torch_kv_cache.view(cache_dtype_actual).to(dtype) + triton_kv_cache = triton_kv_cache.view(cache_dtype_actual).to(dtype) + + torch.testing.assert_close( + torch_kv_cache[slot_mapping, :, :], + triton_kv_cache[slot_mapping, :, :], + atol=1e-1, + rtol=1e-1, + ) + + torch.testing.assert_close(torch_kv_cache, triton_kv_cache, atol=1e-1, rtol=1e-1) + + +@pytest.mark.parametrize("T", [1, 2, 4, 2048]) +@pytest.mark.parametrize("QH_per_KH", [1, 16]) +@pytest.mark.parametrize("KH", [1, 8]) +@pytest.mark.parametrize("D", [64]) # For now, D is power of 2. D >= 16 +@pytest.mark.parametrize("num_kv_cahce_tokens", [16384]) +@pytest.mark.parametrize("rotate_style", [RotateStyle.GPTJ, RotateStyle.NEOX]) +@pytest.mark.parametrize("reuse_freqs_front_part", [False, True]) +@pytest.mark.parametrize("cache_dtype", [torch.bfloat16, torch.uint8]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("cache_flash", [False, True]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("x_size", [8]) +@pytest.mark.parametrize("offs", [False, True]) +def test_fused_qk_rope_reshape_and_cache( + T: int, + QH_per_KH: int, + KH: int, + D: int, + num_kv_cahce_tokens: int, + rotate_style: int, + reuse_freqs_front_part: bool, + block_size: int, + x_size: int, + cache_flash: bool, + cache_dtype: bool, + offs: bool, + dtype: torch.dtype, +): + pos = True + q, k, _, _, freqs, positions, offsets, cos, sin = generate_rope_inputs( + 1, + T, + KH, + QH_per_KH, + D, + cached=True, + reuse_freqs_front_part=reuse_freqs_front_part, + nope=False, + pos=pos, + offs=offs, + two_inputs=True, + layout="thd", + dtype=dtype, + ) + v = torch.randn_like(k) + + if cache_dtype == torch.uint8: + if arch_info.get_arch() in ["gfx950"]: + cache_dtype_actual = 
torch.float8_e4m3fn + else: + cache_dtype_actual = torch.float8_e4m3fnuz + pytest.skip("Skipping FP8 dtype cases non-gfx950") + + if cache_flash: + key_cache = torch.zeros( + (num_kv_cahce_tokens, block_size, KH, D), dtype=cache_dtype, device="cuda" + ) + value_cache = torch.zeros( + (num_kv_cahce_tokens, block_size, KH, D), dtype=cache_dtype, device="cuda" + ) + else: + key_cache = torch.zeros( + (num_kv_cahce_tokens, KH, D // x_size, block_size, x_size), + dtype=cache_dtype, + device="cuda", + ) + value_cache = torch.zeros( + (num_kv_cahce_tokens, KH, D, block_size), dtype=cache_dtype, device="cuda" + ) + if cache_dtype == torch.uint8: + k_scale = torch.randn( + [ + 1, + ], + dtype=torch.float32, + device="cuda", + )[0] + v_scale = torch.randn( + [ + 1, + ], + dtype=torch.float32, + device="cuda", + )[0] + else: + k_scale = torch.ones( + [ + 1, + ], + dtype=torch.float32, + device="cuda", + )[0] + v_scale = torch.ones( + [ + 1, + ], + dtype=torch.float32, + device="cuda", + )[0] + slot_mapping = torch.randperm(T, device="cuda") + key_cache_og_dtype = key_cache.dtype + value_cache_og_dtype = value_cache.dtype + + ref_freqs = ( + freqs[positions if offsets is None else torch.add(positions, offsets)].squeeze( + -2 + ) + if pos + else freqs + ) + + torch_q = ref_rope_sbhd_fwd( + q.unsqueeze(0), + ref_freqs, + rotate_style=rotate_style, + reuse_freqs_front_part=reuse_freqs_front_part, + nope_first=False, + ).squeeze(0) + torch_k = ref_rope_sbhd_fwd( + k.unsqueeze(0), + ref_freqs, + rotate_style=rotate_style, + reuse_freqs_front_part=reuse_freqs_front_part, + nope_first=False, + ).squeeze(0) + + torch_key_cache = key_cache.clone() + torch_value_cache = value_cache.clone() + slot_t = slot_mapping // block_size + slot_b = slot_mapping % block_size + torch_k_og_dtype = torch_k.clone() + if cache_dtype == torch.uint8: + torch_key_cache = torch_key_cache.view(cache_dtype_actual) + torch_value_cache = torch_value_cache.view(cache_dtype_actual) + torch_k = 
(torch_k.to(torch.float32) / k_scale).to(cache_dtype_actual) + torch_v = (v.to(torch.float32) / v_scale).to(cache_dtype_actual) + else: + torch_v = v + torch_zeros = torch.zeros_like(q) + if cache_flash: + torch_key_cache[slot_t, slot_b] = torch_k + torch_value_cache[slot_t, slot_b] = torch_v + else: + torch_key_cache[slot_t, :, :, slot_b, :] = torch_k.reshape( + T, KH, D // x_size, x_size + ) + torch_value_cache[slot_t, :, :, slot_b] = torch_v + torch_key_cache = torch_key_cache.view(key_cache_og_dtype) + torch_value_cache = torch_value_cache.view(value_cache_og_dtype) + + triton_key_cache = key_cache.clone() + triton_value_cache = value_cache.clone() + if cache_dtype == torch.uint8: + triton_key_cache = triton_key_cache.view(cache_dtype_actual) + triton_value_cache = triton_value_cache.view(cache_dtype_actual) + triton_q, triton_k, triton_key_cache, triton_value_cache, triton_zeros = ( + fused_qk_rope_reshape_and_cache( + q, + k, + v, + triton_key_cache, + triton_value_cache, + slot_mapping, + positions, + cos, + sin, + k_scale, + v_scale, + (rotate_style == RotateStyle.NEOX), + flash_layout=cache_flash, + apply_scale=(cache_dtype != torch.bfloat16), + offs=offsets, + q_out=q, + k_out=k, + ) + ) + triton_key_cache = triton_key_cache.view(key_cache_og_dtype) + triton_value_cache = triton_value_cache.view(value_cache_og_dtype) + + torch.testing.assert_close(torch_q, triton_q, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(torch_k_og_dtype, triton_k, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(torch_zeros, triton_zeros, atol=0.1, rtol=0.1) + + if cache_dtype == torch.uint8: + torch_key_cache = torch_key_cache.view(cache_dtype_actual).to(dtype) + triton_key_cache = triton_key_cache.view(cache_dtype_actual).to(dtype) + torch_value_cache = torch_value_cache.view(cache_dtype_actual).to(dtype) + triton_value_cache = triton_value_cache.view(cache_dtype_actual).to(dtype) + + if cache_flash: + torch.testing.assert_close( + torch_key_cache[slot_t, slot_b], + 
triton_key_cache[slot_t, slot_b], + atol=1e-1, + rtol=1e-1, + equal_nan=not ( + arch_info.get_arch() in ["gfx950"] + ), # TODO: investigate nan elements for non-gfx950 arch + ) + torch.testing.assert_close( + torch_value_cache[slot_t, slot_b], + triton_value_cache[slot_t, slot_b], + atol=1e-1, + rtol=1e-1, + equal_nan=not (arch_info.get_arch() in ["gfx950"]), + ) + else: + torch.testing.assert_close( + torch_key_cache[slot_t, :, :, slot_b, :], + triton_key_cache[slot_t, :, :, slot_b, :], + atol=1e-1, + rtol=1e-1, + equal_nan=not (arch_info.get_arch() in ["gfx950"]), + ) + torch.testing.assert_close( + torch_value_cache[slot_t, :, :, slot_b], + triton_value_cache[slot_t, :, :, slot_b], + atol=1e-1, + rtol=1e-1, + equal_nan=not (arch_info.get_arch() in ["gfx950"]), + ) + + torch.testing.assert_close( + torch_key_cache, + triton_key_cache, + atol=1e-1, + rtol=1e-1, + equal_nan=not (arch_info.get_arch() in ["gfx950"]), + ) + torch.testing.assert_close( + torch_value_cache, + triton_value_cache, + atol=1e-1, + rtol=1e-1, + equal_nan=not (arch_info.get_arch() in ["gfx950"]), + ) + + @pytest.mark.parametrize("T", [1, 2, 4, 128]) @pytest.mark.parametrize("QH_per_KH", [1, 4, 16]) @pytest.mark.parametrize("KH", [1, 8]) diff --git a/op_tests/triton_tests/test_fused_qkv_split_qk_rope.py b/op_tests/triton_tests/test_fused_qkv_split_qk_rope.py new file mode 100644 index 0000000000..6377123e2c --- /dev/null +++ b/op_tests/triton_tests/test_fused_qkv_split_qk_rope.py @@ -0,0 +1,117 @@ +import torch +import pytest +from aiter.ops.triton.fused_qkv_split_qk_rope import fused_qkv_split_qk_rope +from op_tests.triton_tests.test_fused_qk_concat import generate_rope_cached_freqs +from op_tests.test_rope import ref_rope_sbhd_fwd, RotateStyle + + +def generate_qkv_inputs( + B: int, QH_PER_KH: int, KH: int, D: int, nope: bool, nope_first: bool, dtype +): + qkv = torch.randn( + (B, (QH_PER_KH * KH + 2 * KH) * (D * (2 if nope else 1))), + dtype=dtype, + device="cuda", + ) + return qkv + + +def 
run_torch( + qkv, + QH_PER_KH, + KH, + D, + ref_freqs, + reuse_freqs_front_part, + nope, + nope_first, + rotate_style, +): + q_size = QH_PER_KH * KH * D + kv_size = KH * D + q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) + q = q.view(-1, QH_PER_KH * KH, D).contiguous() + k = k.view(-1, KH, D).contiguous() + v = v.view(-1, KH, D).contiguous() + + q = ref_rope_sbhd_fwd( + q, + ref_freqs, + rotate_style=rotate_style, + reuse_freqs_front_part=reuse_freqs_front_part, + nope_first=nope_first, + ) + k = ref_rope_sbhd_fwd( + k, + ref_freqs, + rotate_style=rotate_style, + reuse_freqs_front_part=reuse_freqs_front_part, + nope_first=nope_first, + ) + + return q, k, v + + +# @pytest.mark.parametrize("B", [32]) +# @pytest.mark.parametrize("QH_PER_KH", [8]) +# @pytest.mark.parametrize("KH", [8]) +# @pytest.mark.parametrize("D", [64]) +@pytest.mark.parametrize("B", [1, 4, 8, 16, 32]) +@pytest.mark.parametrize("QH_PER_KH", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("KH", [1, 4]) +@pytest.mark.parametrize("D", [64, 128]) +@pytest.mark.parametrize("rotate_style", [RotateStyle.GPTJ, RotateStyle.NEOX]) +@pytest.mark.parametrize("max_embed_positions", [131072]) +@pytest.mark.parametrize( + "nope, nope_first", [(False, False), (True, False), (True, True)] +) +@pytest.mark.parametrize("reuse_freqs_front_part", [False, True]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_fused_qkv_split_qk_rope( + B: int, + QH_PER_KH: int, + KH: int, + D: int, + rotate_style: int, + max_embed_positions: int, + nope: bool, + nope_first: bool, + reuse_freqs_front_part: bool, + dtype: torch.dtype, +): + + qkv = generate_qkv_inputs(B, QH_PER_KH, KH, D, nope, nope_first, dtype) + + pos, freqs, cos, sin = generate_rope_cached_freqs( + B, max_embed_positions, (D // 2) if reuse_freqs_front_part else D, dtype + ) + ref_freqs = freqs[pos].squeeze(-2) + + q_triton, k_triton, v_triton = fused_qkv_split_qk_rope( + qkv, + cos, + sin, + pos, + QH_PER_KH * KH, + KH, + (D * (2 if nope else 1)), 
+ is_neox=(rotate_style == RotateStyle.NEOX), + offsets=None, + reuse_freqs_front_part=reuse_freqs_front_part, + nope_first=nope_first, + ) + q_torch, k_torch, v_torch = run_torch( + qkv, + QH_PER_KH, + KH, + (D * (2 if nope else 1)), + ref_freqs, + reuse_freqs_front_part, + nope, + nope_first, + rotate_style, + ) + + torch.testing.assert_close(q_torch, q_triton) + torch.testing.assert_close(k_torch, k_triton) + torch.testing.assert_close(v_torch, v_triton) diff --git a/op_tests/triton_tests/test_gemm_a16w16.py b/op_tests/triton_tests/test_gemm_a16w16.py index 4d6820ce2f..66f64b6119 100644 --- a/op_tests/triton_tests/test_gemm_a16w16.py +++ b/op_tests/triton_tests/test_gemm_a16w16.py @@ -9,7 +9,7 @@ from op_tests.triton_tests.utils.types import str_to_torch_dtype -def generate_gemm_a16w16_inputs(M, N, K, dtype, layout="TN", output=True): +def generate_gemm_a16w16_inputs(M, N, K, dtype, layout="TN", output=True, bias=False): if isinstance(dtype, str): dtype = str_to_torch_dtype[dtype] @@ -24,6 +24,10 @@ def generate_gemm_a16w16_inputs(M, N, K, dtype, layout="TN", output=True): else: weight = torch.randn((N, K), dtype=dtype, device="cuda") + bias_tensor = None + if bias: + bias_tensor = torch.empty((N), dtype=dtype, device="cuda") + y = None if output: y = torch.empty((M, N), dtype=dtype, device="cuda") @@ -31,7 +35,7 @@ def generate_gemm_a16w16_inputs(M, N, K, dtype, layout="TN", output=True): else: out_dtype = dtype - return x, weight, out_dtype, y + return x, weight, bias_tensor, out_dtype, y def get_x_vals(): @@ -76,7 +80,7 @@ def get_x_vals(): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("output", [True, False]) def test_gemm_a16_w16_activation(M: int, N: int, K: int, dtype, output, activation): - x, w, out_dtype, y = generate_gemm_a16w16_inputs( + x, w, _, out_dtype, y = generate_gemm_a16w16_inputs( M, N, K, @@ -98,6 +102,7 @@ def test_gemm_a16_w16_activation(M: int, N: int, K: int, dtype, output, activati 
triton_out = gemm_a16w16( x, w, + None, out_dtype, y, activation=activation, @@ -106,6 +111,7 @@ def test_gemm_a16_w16_activation(M: int, N: int, K: int, dtype, output, activati triton_out = gemm_a16w16( x, w, + None, out_dtype, activation=activation, ) @@ -119,14 +125,16 @@ def test_gemm_a16_w16_activation(M: int, N: int, K: int, dtype, output, activati def test_gemm_a16_w16(M: int, N: int, K: int, dtype, output): torch.cuda.empty_cache() # Helps avoid hangs in large tests - x, w, out_dtype, y = generate_gemm_a16w16_inputs(M, N, K, dtype, output=output) + x, w, bias, out_dtype, y = generate_gemm_a16w16_inputs( + M, N, K, dtype, output=output, bias=True + ) - torch_out = F.linear(x, w, bias=None) + torch_out = F.linear(x, w, bias=bias) if output: - triton_out = gemm_a16w16(x, w, out_dtype, y) + triton_out = gemm_a16w16(x, w, bias, out_dtype, y) else: - triton_out = gemm_a16w16(x, w, out_dtype) + triton_out = gemm_a16w16(x, w, bias, out_dtype) torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1) @@ -138,16 +146,16 @@ def test_gemm_a16_w16(M: int, N: int, K: int, dtype, output): def test_gemm_a16_w16_layout(M: int, N: int, K: int, dtype, layout, output): torch.cuda.empty_cache() # Helps avoid hangs in large tests - x, w, out_dtype, y = generate_gemm_a16w16_inputs( + x, w, _, out_dtype, y = generate_gemm_a16w16_inputs( M, N, K, dtype, layout=layout, output=output ) torch_out = F.linear(x, w, bias=None) if output: - triton_out = gemm_a16w16(x, w, out_dtype, y) + triton_out = gemm_a16w16(x, w, None, out_dtype, y) else: - triton_out = gemm_a16w16(x, w, out_dtype) + triton_out = gemm_a16w16(x, w, None, out_dtype) torch.testing.assert_close(triton_out, torch_out, atol=1e-1, rtol=1e-1) @@ -158,7 +166,7 @@ def test_gemm_a16_w16_layout(M: int, N: int, K: int, dtype, layout, output): def test_gemm_a16_w16_atomic(M: int, N: int, K: int, dtype, output): torch.cuda.empty_cache() # Helps avoid hangs in large tests - x, w, out_dtype, y = 
generate_gemm_a16w16_inputs(M, N, K, dtype, output=output) + x, w, _, out_dtype, y = generate_gemm_a16w16_inputs(M, N, K, dtype, output=output) torch_out = F.linear(x, w, bias=None) @@ -179,7 +187,7 @@ def test_gemm_a16_w16_atomic(M: int, N: int, K: int, dtype, output): def test_gemm_a16_w16_atomic_layout(M: int, N: int, K: int, dtype, layout, output): torch.cuda.empty_cache() # Helps avoid hangs in large tests - x, w, out_dtype, y = generate_gemm_a16w16_inputs( + x, w, _, out_dtype, y = generate_gemm_a16w16_inputs( M, N, K, dtype, layout=layout, output=output ) diff --git a/op_tests/triton_tests/test_gemm_afp4wfp4.py b/op_tests/triton_tests/test_gemm_afp4wfp4.py index e517144656..1ae27efbe8 100644 --- a/op_tests/triton_tests/test_gemm_afp4wfp4.py +++ b/op_tests/triton_tests/test_gemm_afp4wfp4.py @@ -6,13 +6,11 @@ from aiter.ops.triton.gemm_afp4wfp4 import ( gemm_afp4wfp4, gemm_afp4wfp4_preshuffled_scales, + gemm_afp4wfp4_preshuffled_weight_scales, ) import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton.utils.types import str_to_torch_dtype - -TRITON_HIP_PRESHUFFLE_SCALES = ( - os.environ.get("TRITON_HIP_PRESHUFFLE_SCALES", "0") == "1" -) +from aiter.ops.shuffle import shuffle_weight def shuffle_scales(scales: torch.Tensor): @@ -28,7 +26,21 @@ def shuffle_scales(scales: torch.Tensor): SCALE_GROUP_SIZE = 32 -def generate_gemm_afp4wfp4_inputs(M, N, K, dtype, layout="TN", output=True): +def generate_gemm_afp4wfp4_inputs( + M, + N, + K, + dtype, + layout="TN", + output=True, + shuffle_weight_fg=False, + shuffle_scales_fg=False, +): + if shuffle_weight_fg: + assert ( + shuffle_scales_fg + ), "weight shuffling is only supported with scale shuffling" + torch.manual_seed(5) if isinstance(dtype, str): dtype = str_to_torch_dtype[dtype] @@ -55,7 +67,7 @@ def generate_gemm_afp4wfp4_inputs(M, N, K, dtype, layout="TN", output=True): w = w_low | w_high << 4 # Scale of 1.0 in e8m0, bias 127. 
- if M >= 32 and TRITON_HIP_PRESHUFFLE_SCALES: + if M >= 32 and shuffle_scales_fg: M_pad = (M + 255) // 256 * 256 else: M_pad = M @@ -67,19 +79,31 @@ def generate_gemm_afp4wfp4_inputs(M, N, K, dtype, layout="TN", output=True): ) x_scales = x_scales.T w_scales = w_scales.T - if TRITON_HIP_PRESHUFFLE_SCALES: + if shuffle_scales_fg: if M >= 32: x_scales_shuffled = shuffle_scales(x_scales) else: - x_scales_shuffled = x_scales + x_scales_shuffled = x_scales.contiguous() w_scales_shuffled = shuffle_scales(w_scales) else: x_scales_shuffled = x_scales w_scales_shuffled = w_scales + if shuffle_weight_fg: + use_int4 = False + weight_shuffle_layout = (16, 16) + w_shuffed = shuffle_weight( + w, layout=weight_shuffle_layout, use_int4=use_int4 + ).reshape( + w.shape[0] // weight_shuffle_layout[0], + w.shape[1] * weight_shuffle_layout[0], + ) + else: + w_shuffed = w + y = None if output: - y = torch.empty((M, N), dtype=dtype, device="cuda") + y = torch.empty((M, N), dtype=dtype).cuda() out_dtype = (None,) else: out_dtype = dtype @@ -87,6 +111,7 @@ def generate_gemm_afp4wfp4_inputs(M, N, K, dtype, layout="TN", output=True): return ( x, w, + w_shuffed, x_scales[:M], w_scales, x_scales_shuffled, @@ -133,6 +158,10 @@ def get_x_vals(): x_vals += [(16, 16384, 3328 * 2), (128, 16384, 3328 * 2)] x_vals += [(256, 3584, 2112)] x_vals += [(7, 4608, 7168), (7, 7168, 2304)] + x_vals += [(v, 106496, 16384) for v in [1, 8, 16, 32, 64, 128, 256]] + x_vals += [(v, 16384, 53248) for v in [1, 8, 16, 32, 64, 128, 256]] + x_vals += [(v, 18432, 16384) for v in [1, 8, 16, 32, 64, 128, 256]] + x_vals += [(v, 16384, 16384) for v in [1, 8, 16, 32, 64, 128, 256]] x_vals += [(1, 1, 32)] # minimal case return x_vals @@ -188,48 +217,81 @@ def run_torch(x, w, x_scales, w_scales, dtype): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("layout", ["TN", "TT", "NN", "NT"]) @pytest.mark.parametrize("output", [True, False]) -def test_gemm_afp4_wfp4(M: int, N: int, K: int, 
dtype, layout, output): +@pytest.mark.parametrize( + "shuffle_scales_fg, shuffle_weight_fg", + [(False, False), (True, False), (True, True)], +) +def test_gemm_afp4_wfp4( + M: int, N: int, K: int, dtype, layout, output, shuffle_scales_fg, shuffle_weight_fg +): if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") - torch.cuda.empty_cache() # Helps avoid hangs in large tests + if shuffle_weight_fg and not shuffle_scales_fg: + pytest.skip("Preshuffling weight without preshuffled scales is not supported") + + if shuffle_weight_fg or shuffle_scales_fg: + if shuffle_scales_fg and not shuffle_weight_fg and M < 32: + pytest.skip("Minimal tile size for preshuffled scales is 32x32x256") - if TRITON_HIP_PRESHUFFLE_SCALES: if N % 32 > 0: pytest.skip( - f"N = {N} is not divisible by 32, skip this test for preshuffled scales tests" + f"N = {N} is not divisible by 32, skip this test for preshuffled weight/scales tests" ) elif K % 256 > 0: pytest.skip( - f"K = {K} is not divisible by 256, skip this test for preshuffled scales tests" + f"K = {K} is not divisible by 256, skip this test for preshuffled weight/scales tests" ) ( x, w, + w_triton, x_scales, w_scales, x_scales_triton, w_scales_triton, out_dtype, y, - ) = generate_gemm_afp4wfp4_inputs(M, N, K, dtype, layout=layout, output=output) + ) = generate_gemm_afp4wfp4_inputs( + M, + N, + K, + dtype, + layout=layout, + output=output, + shuffle_scales_fg=shuffle_scales_fg, + shuffle_weight_fg=shuffle_weight_fg, + ) torch_out = run_torch(x, w, x_scales, w_scales, dtype).to(dtype) - if TRITON_HIP_PRESHUFFLE_SCALES: + if shuffle_scales_fg and shuffle_weight_fg: + if output: + triton_out = gemm_afp4wfp4_preshuffled_weight_scales( + x, w_triton, x_scales_triton, w_scales_triton, dtype, y + ) + else: + triton_out = gemm_afp4wfp4_preshuffled_weight_scales( + x, w_triton, x_scales_triton, w_scales_triton, dtype + ) + elif shuffle_scales_fg and not shuffle_weight_fg: if output: triton_out = 
gemm_afp4wfp4_preshuffled_scales(
-                x, w, x_scales_triton, w_scales_triton, dtype, y
+                x, w_triton, x_scales_triton, w_scales_triton, dtype, y
             )
         else:
             triton_out = gemm_afp4wfp4_preshuffled_scales(
-                x, w, x_scales_triton, w_scales_triton, dtype
+                x, w_triton, x_scales_triton, w_scales_triton, dtype
             )
     else:
         if output:
-            triton_out = gemm_afp4wfp4(x, w, x_scales_triton, w_scales_triton, dtype, y)
+            triton_out = gemm_afp4wfp4(
+                x, w_triton, x_scales_triton, w_scales_triton, dtype, y
+            )
         else:
-            triton_out = gemm_afp4wfp4(x, w, x_scales_triton, w_scales_triton, dtype)
+            triton_out = gemm_afp4wfp4(
+                x, w_triton, x_scales_triton, w_scales_triton, dtype
+            )

     torch.testing.assert_close(torch_out, triton_out)
diff --git a/op_tests/triton_tests/test_mha.py b/op_tests/triton_tests/test_mha.py
index da10db4313..8bae346de0 100644
--- a/op_tests/triton_tests/test_mha.py
+++ b/op_tests/triton_tests/test_mha.py
@@ -456,9 +456,9 @@ def test_mha_varlen(
     )
 
     if FP8:
-        torch.testing.assert_close(
-            triton_out, torch_out.to(triton_out.dtype), atol=0.25, rtol=10
-        )  # Lower tolerance for FP8
+        fp8_assert_close(
+            triton_out, torch_out.to(triton_out.dtype), atol=ATOL_fp8, rtol=RTOL_fp8
+        )
     else:
         torch.testing.assert_close(
             triton_out, torch_out.to(triton_out.dtype), atol=1e-1, rtol=1e-1
diff --git a/op_tests/triton_tests/test_unified_attention.py b/op_tests/triton_tests/test_unified_attention.py
new file mode 100644
index 0000000000..6aa29deabe
--- /dev/null
+++ b/op_tests/triton_tests/test_unified_attention.py
@@ -0,0 +1,202 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+import pytest
+import torch
+
+from aiter.ops.triton.unified_attention import unified_attention
+from aiter.ops.triton.utils.types import e4m3_dtype
+
+NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
+HEAD_SIZES = [128, 256]
+BLOCK_SIZES = [16, 64]
+
+DTYPES = [torch.float16, torch.bfloat16]
+QDTYPES = [None, e4m3_dtype]
+# one value
large enough to test overflow in index calculation. +# one value small enough to test the schema op check +NUM_BLOCKS = [32768, 2048] + + +def ref_paged_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: list[int], + kv_lens: list[int], + block_tables: torch.Tensor, + scale: float, + sliding_window: Optional[int] = None, + soft_cap: Optional[float] = None, + sinks: Optional[torch.Tensor] = None, +) -> torch.Tensor: + num_seqs = len(query_lens) + block_tables = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + + outputs: list[torch.Tensor] = [] + start_idx = 0 + for i in range(num_seqs): + query_len = query_lens[i] + kv_len = kv_lens[i] + q = query[start_idx : start_idx + query_len] + q *= scale + + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables[i, :num_kv_blocks] + + k = key_cache[block_indices].view(-1, num_kv_heads, head_size) + k = k[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size) + v = v[:kv_len] + + if q.shape[1] != k.shape[1]: + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k).float() + empty_mask = torch.ones(query_len, kv_len, device=q.device) + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() + if sliding_window is not None: + sliding_window_mask = ( + torch.triu( + empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1 + ) + .bool() + .logical_not() + ) + mask |= sliding_window_mask + if soft_cap is not None and soft_cap > 0: + attn = soft_cap * torch.tanh(attn / soft_cap) + attn.masked_fill_(mask, float("-inf")) + if sinks is not None: + s_aux = sinks[:, None, None].repeat_interleave(attn.shape[-2], dim=-2) + attn = torch.cat((attn, s_aux), dim=-1) + attn = torch.softmax(attn, dim=-1).to(v.dtype) + if sinks is not None: + attn = attn[..., :-1] + out = 
torch.einsum("hqk,khd->qhd", attn, v) + + outputs.append(out) + start_idx += query_len + + return torch.cat(outputs, dim=0) + + +@pytest.mark.parametrize( + "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]] +) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("sliding_window", [None, 256]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("q_dtype", QDTYPES) +@torch.inference_mode() +def test_triton_unified_attn( + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + head_size: int, + sliding_window: Optional[int], + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], + num_blocks: int, + q_dtype: Optional[torch.dtype], +) -> None: + if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32: + pytest.skip("block size must be at least 32 for fp8") + + torch.manual_seed(0) + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_query_len = max(query_lens) + max_kv_len = max(kv_lens) + window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) + scale = head_size**-0.5 + + query = torch.randn( + sum(query_lens), num_query_heads, head_size, dtype=dtype, device="cuda" + ) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device="cuda" + ) + value_cache = torch.randn_like(key_cache) + cu_query_lens = torch.tensor( + [0] + query_lens, dtype=torch.int32, device="cuda" + ).cumsum(dim=0, dtype=torch.int32) + kv_lens = torch.tensor(kv_lens, dtype=torch.int32, device="cuda") + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // 
block_size
+    block_tables = torch.randint(
+        0,
+        num_blocks,
+        (num_seqs, max_num_blocks_per_seq),
+        dtype=torch.int32,
+        device="cuda",
+    )
+    sinks = torch.randn(num_query_heads, dtype=torch.bfloat16, device="cuda")
+    output = torch.empty_like(query)
+
+    maybe_quantized_query = query
+    maybe_quantized_key_cache = key_cache
+    maybe_quantized_value_cache = value_cache
+    q_descale = None
+    k_descale = None
+    v_descale = None
+    if q_dtype is not None:
+        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
+        maybe_quantized_query = query.to(q_dtype)
+        maybe_quantized_key_cache = key_cache.to(q_dtype)
+        maybe_quantized_value_cache = value_cache.to(q_dtype)
+
+        scale_shape = (num_seqs, num_kv_heads)
+        q_descale = None  # Not yet supported
+        k_descale = torch.rand(scale_shape, dtype=torch.float32, device="cuda")
+        v_descale = torch.rand(scale_shape, dtype=torch.float32, device="cuda")
+
+    unified_attention(
+        q=maybe_quantized_query,
+        k=maybe_quantized_key_cache,
+        v=maybe_quantized_value_cache,
+        out=output,
+        cu_seqlens_q=cu_query_lens,
+        seqused_k=kv_lens,
+        max_seqlen_q=max_query_len,
+        max_seqlen_k=max_kv_len,
+        softmax_scale=scale,
+        causal=True,
+        window_size=window_size,
+        block_table=block_tables,
+        softcap=soft_cap if soft_cap is not None else 0,
+        q_descale=q_descale,
+        k_descale=k_descale,
+        v_descale=v_descale,
+        sinks=sinks,
+    )
+
+    ref_output = ref_paged_attn(
+        query=query,
+        key_cache=key_cache,
+        value_cache=value_cache,
+        query_lens=query_lens,
+        kv_lens=kv_lens,
+        block_tables=block_tables,
+        scale=scale,
+        sliding_window=sliding_window,
+        soft_cap=soft_cap,
+        sinks=sinks,
+    )
+    atol, rtol = 1.5e-2, 1e-2
+    if q_dtype is not None:
+        atol, rtol = 1.5e-1, 1.5e-1
+    torch.testing.assert_close(
+        output, ref_output, atol=atol, rtol=rtol
+    )  # assert_close raises with full diagnostics; a trailing ", msg" tuple is dead code