diff --git a/aiter/ops/triton/_triton_kernels/batched_gemm_a8w8.py b/aiter/ops/triton/_triton_kernels/batched_gemm_a8w8.py index afbb244ece..8abfc4fcd1 100644 --- a/aiter/ops/triton/_triton_kernels/batched_gemm_a8w8.py +++ b/aiter/ops/triton/_triton_kernels/batched_gemm_a8w8.py @@ -5,9 +5,23 @@ import json import triton import triton.language as tl -from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH +from ..utils._triton.kernel_repr import make_kernel_repr + + +_batched_gemm_a8w8_repr = make_kernel_repr( + "_batched_gemm_a8w8_kernel", + [ + "HAS_BIAS", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "GRID_MN", + ], +) @triton.heuristics( @@ -17,7 +31,7 @@ * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), } ) -@triton.jit +@triton.jit(repr=_batched_gemm_a8w8_repr) def _batched_gemm_a8w8_kernel( # Pointers to matrices a_ptr, diff --git a/aiter/ops/triton/_triton_kernels/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py b/aiter/ops/triton/_triton_kernels/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py index 27bd419c3b..30ed631dbf 100644 --- a/aiter/ops/triton/_triton_kernels/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py +++ b/aiter/ops/triton/_triton_kernels/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py @@ -8,6 +8,22 @@ import triton.language as tl from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH +from ..utils._triton.kernel_repr import make_kernel_repr + + +_batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant_repr = make_kernel_repr( + "_batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant_kernel", + [ + "HAS_BIAS", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "cache_modifier", + "GRID_MN", + ], +) 
@triton.heuristics( @@ -17,7 +33,9 @@ * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), } ) -@triton.jit +@triton.jit( + repr=_batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant_repr +) def _batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant_kernel( # Pointers to matrices a_ptr, diff --git a/aiter/ops/triton/_triton_kernels/batched_gemm_afp4wfp4.py b/aiter/ops/triton/_triton_kernels/batched_gemm_afp4wfp4.py index add6a2d222..48e1a730b2 100755 --- a/aiter/ops/triton/_triton_kernels/batched_gemm_afp4wfp4.py +++ b/aiter/ops/triton/_triton_kernels/batched_gemm_afp4wfp4.py @@ -9,6 +9,33 @@ from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH +from ..utils._triton.kernel_repr import make_kernel_repr + + +_batched_gemm_afp4_wfp4_repr = make_kernel_repr( + "_batched_gemm_afp4_wfp4_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + "EVEN_K", + "GRID_MN", + "cache_modifier", + ], +) + +_batched_gemm_afp4_wfp4_reduce_repr = make_kernel_repr( + "_batched_gemm_afp4_wfp4_reduce_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "ACTUAL_KSPLIT", + "MAX_KSPLIT", + ], +) @triton.heuristics( @@ -20,7 +47,7 @@ * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), } ) -@triton.jit +@triton.jit(repr=_batched_gemm_afp4_wfp4_repr) def _batched_gemm_afp4_wfp4_kernel( a_ptr, b_ptr, @@ -210,7 +237,7 @@ def _batched_gemm_afp4_wfp4_kernel( tl.store(c_ptrs, c, mask=c_mask) -@triton.jit +@triton.jit(repr=_batched_gemm_afp4_wfp4_reduce_repr) def _batched_gemm_afp4_wfp4_reduce_kernel( c_in_ptr, c_out_ptr, diff --git a/aiter/ops/triton/_triton_kernels/batched_gemm_afp4wfp4_pre_quant.py b/aiter/ops/triton/_triton_kernels/batched_gemm_afp4wfp4_pre_quant.py index d3fe88a1cd..d94b3ae958 100755 --- a/aiter/ops/triton/_triton_kernels/batched_gemm_afp4wfp4_pre_quant.py +++ 
b/aiter/ops/triton/_triton_kernels/batched_gemm_afp4wfp4_pre_quant.py @@ -10,6 +10,34 @@ from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH from .quant import _mxfp4_quant_op +from ..utils._triton.kernel_repr import make_kernel_repr + + +_batched_gemm_afp4_wfp4_pre_quant_repr = make_kernel_repr( + "_batched_gemm_afp4_wfp4_pre_quant_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "NUM_KSPLIT", + "SPLITK_BLOCK_SIZE", + "EVEN_K", + "GRID_MN", + "cache_modifier", + ], +) + + +_batched_gemm_afp4_wfp4_pre_quant_reduce_repr = make_kernel_repr( + "_batched_gemm_afp4_wfp4_pre_quant_reduce_kernel", + [ + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "ACTUAL_KSPLIT", + "MAX_KSPLIT", + ], +) @triton.heuristics( @@ -21,7 +49,7 @@ * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), } ) -@triton.jit +@triton.jit(repr=_batched_gemm_afp4_wfp4_pre_quant_repr) def _batched_gemm_afp4_wfp4_pre_quant_kernel( a_ptr, b_ptr, @@ -54,7 +82,8 @@ def _batched_gemm_afp4_wfp4_pre_quant_kernel( GRID_MN: tl.constexpr, cache_modifier: tl.constexpr, ): - """Kernel for computing the matmul C = A x B. + """ + Kernel for computing the matmul C = A x B. A and B inputs are in the microscale fp4 (mxfp4) format. A_scales and B_scales are in e8m0 format. 
A has shape (M, K), B has shape (K, N) and C has shape (M, N) @@ -184,7 +213,7 @@ def _batched_gemm_afp4_wfp4_pre_quant_kernel( tl.store(c_ptrs, c, mask=c_mask) -@triton.jit +@triton.jit(repr=_batched_gemm_afp4_wfp4_pre_quant_reduce_repr) def _batched_gemm_afp4_wfp4_pre_quant_reduce_kernel( c_in_ptr, c_out_ptr, diff --git a/aiter/ops/triton/_triton_kernels/batched_gemm_bf16.py b/aiter/ops/triton/_triton_kernels/batched_gemm_bf16.py index 11329f15d1..178202e950 100644 --- a/aiter/ops/triton/_triton_kernels/batched_gemm_bf16.py +++ b/aiter/ops/triton/_triton_kernels/batched_gemm_bf16.py @@ -5,9 +5,23 @@ import json import triton import triton.language as tl -from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH +from ..utils._triton.kernel_repr import make_kernel_repr + + +_batched_gemm_bf16_repr = make_kernel_repr( + "_batched_gemm_bf16_kernel", + [ + "HAS_BIAS", + "BLOCK_SIZE_M", + "BLOCK_SIZE_N", + "BLOCK_SIZE_K", + "GROUP_SIZE_M", + "EVEN_K", + "GRID_MN", + ], +) @triton.heuristics( @@ -17,7 +31,7 @@ * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]), } ) -@triton.jit +@triton.jit(repr=_batched_gemm_bf16_repr) def _batched_gemm_bf16_kernel( # Pointers to matrices a_ptr, diff --git a/aiter/ops/triton/batched_gemm_a8w8.py b/aiter/ops/triton/batched_gemm_a8w8.py index 5c4f39103e..1ac5e6fe3a 100644 --- a/aiter/ops/triton/batched_gemm_a8w8.py +++ b/aiter/ops/triton/batched_gemm_a8w8.py @@ -4,9 +4,6 @@ from typing import Optional import torch import triton -import triton.language as tl -import aiter.ops.triton.utils._triton.arch_info as arch_info -from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH from aiter.ops.triton._triton_kernels.batched_gemm_a8w8 import ( _batched_gemm_a8w8_kernel, _get_config, @@ -28,22 +25,23 @@ def batched_gemm_a8w8( config: Optional[dict] = None, ): """ - Computes the matmul YQ[i] = XQ[i] x WQ[i]T and applies a conversion scale for 
every i in a given batch. - Optionally, adds a bias to each result. - - The conversion scale for each matmul is received in the form of two 1D tensors that are multiplied to form a - 2D one before being applied. - - Key parameters: - - XQ: Batch tensor XQ with shape (B, M, K). - - WQ: Batch tensor WQ with shape (B, N, K). - - X_scale: First scale batch tensor with shape (B, M, 1). - - W_scale: Second scale batch tensor with shape (B, 1, N). - - Bias: Bias batch tensor with shape (B, 1, N). - - YQ: Output Matrix Y with shape (B, M, N). If this is none, then it's created by this API and returned as output + Computes batched 8-bit matrix multiplication Y[i] = X[i] @ W[i]^T with per-row (x_scale) and per-column (w_scale) scaling. + Each batch element is independently scaled back to higher precision; a bias is optionally added to each result. + + Args: + XQ (torch.Tensor): INT8 input batch with shape (B, M, K). + WQ (torch.Tensor): INT8 weight batch with shape (B, N, K), internally transposed. + x_scale (torch.Tensor): Scale for XQ with shape (B, M, 1). + w_scale (torch.Tensor): Scale for WQ with shape (B, 1, N). + bias (Optional[torch.Tensor]): Bias batch with shape (B, 1, N). + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + splitK (Optional[int]): Not supported. Must be None. + YQ (Optional[torch.Tensor]): Pre-allocated output tensor with shape (B, M, N); created by this API when None. + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M). Returns: - - YQ: The output batch tensor with shape (B, M, N). + torch.Tensor: Output batch with shape (B, M, N). 
""" _LOGGER.info( f"BATCHED_GEMM_A8W8: x={tuple(XQ.shape)} w={tuple(WQ.shape)} x_scale={tuple(x_scale.shape)} w_scale={tuple(w_scale.shape)}" diff --git a/aiter/ops/triton/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py b/aiter/ops/triton/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py index 64d2e73f99..9fd8e505b8 100644 --- a/aiter/ops/triton/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py +++ b/aiter/ops/triton/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py @@ -4,9 +4,6 @@ from typing import Optional import torch import triton -import triton.language as tl -import aiter.ops.triton.utils._triton.arch_info as arch_info -from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH from aiter.ops.triton._triton_kernels.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( _batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant_kernel, _get_config, @@ -27,21 +24,26 @@ def batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( config: Optional[dict] = None, ): """ - Computes the matmul YQ[i] = XQ[i] x WQ[i]T and applies a conversion scale for every i in a given batch. - Optionally, adds a bias to each result. + Computes batched 8-bit matrix multiplication Y[i] = X[i] @ W[i]^T with on-the-fly activation quantization. + X is quantized to INT8 during computation using per-token grouped quantization. + W is pre-quantized INT8 with a single scale shared by the whole batched tensor. - The conversion scale for each matmul is received in the form of two 1D tensors that are multiplied to form a - 2D one before being applied. - - Key parameters: - - XQ: Batch tensor XQ with shape (B, M, K) if transpose_bm_in == False else (M, B, K). - - WQ: Batch tensor WQ with shape (B, N, K). - - W_scale: Second scale batch tensor with shape (1, ). - - Bias: Bias batch tensor with shape (B, 1, N). - - YQ: Output Matrix Y with shape (B, M, N). 
If this is none, then it's created by this API and returned as output + Args: + X (torch.Tensor): Higher precision input batch with shape (B, M, K) or (M, B, K) if transpose_bm_in=True. + Quantized to INT8 on-the-fly during GEMM. + WQ (torch.Tensor): Pre-quantized INT8 weight batch with shape (B, N, K), internally transposed. + w_scale (torch.Tensor): Single scale for the whole batched WQ tensor, with shape (1,). + group_size (int): Group size for per-token grouped quantization of X. Must be a power of 2. + bias (Optional[torch.Tensor]): Bias batch with shape (B, 1, N). + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + splitK (Optional[int]): Not supported. Must be None. + YQ (Optional[torch.Tensor]): Pre-allocated output tensor with shape (B, M, N) or (M, B, N) if transpose_bm=True; created by this API when None. + transpose_bm (Optional[bool]): Transpose batch and M dimensions in output. + transpose_bm_in (Optional[bool]): Transpose batch and M dimensions in input. + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M). Returns: - - YQ: The output batch tensor with shape (B, M, N) if transpose_bm == False else (M, B, N). + torch.Tensor: Output batch with shape (B, M, N) or (M, B, N) if transpose_bm=True. """ # Check constraints. diff --git a/aiter/ops/triton/batched_gemm_afp4wfp4.py b/aiter/ops/triton/batched_gemm_afp4wfp4.py index 4c26506c64..046ee0040a 100755 --- a/aiter/ops/triton/batched_gemm_afp4wfp4.py +++ b/aiter/ops/triton/batched_gemm_afp4wfp4.py @@ -4,7 +4,6 @@ from typing import Optional import torch import triton -import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton._triton_kernels.batched_gemm_afp4wfp4 import ( _batched_gemm_afp4_wfp4_kernel, @@ -34,20 +33,22 @@ def batched_gemm_afp4wfp4( config: Optional[dict] = None, ): """ - Computes the matmul Y = X x W - X and W are e2m1 fp4 tensors. - x_scales and w_scales are e8m0 tensors. - Every 32 elements in the K dimension share one e8m0 scale. 
- - - Key parameters: - - X: Matrix X with shape (B, M, K). - - W: Matrix W with shape (B, N, K). - - X_scales: Matrix with shape (B, M, K // 32) - - W_scales: Matrix with shape (B, N, K // 32) + Computes batched FP4 matrix multiplication Y[i] = X[i] @ W[i]^T with FP4 activations and weights. + + Args: + x (torch.Tensor): FP4 E2M1 input batch with shape (B, M, K). + w (torch.Tensor): FP4 E2M1 weight batch with shape (B, N, K), internally transposed. + x_scales (torch.Tensor): E8M0 per-group scale for x with shape (B, M, K//32). + One scale per 32 elements in K dimension. + w_scales (torch.Tensor): E8M0 per-group scale for w with shape (B, N, K//32). + One scale per 32 elements in K dimension. + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (B, M, N). + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE). Returns: - - Y: The output matrix with shape (B, M, N). + torch.Tensor: Output batch with shape (B, M, N). """ _LOGGER.info( f"BATCHED_GEMM_AFP4WFP4: x={tuple(x.shape)} w={tuple(w.shape)} x_scale={tuple(x.shape)} w_scale={tuple(w.shape)}" diff --git a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py index 8950a11b19..8679344856 100755 --- a/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py +++ b/aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py @@ -4,7 +4,6 @@ from typing import Optional import torch import triton -import triton.language as tl import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton._triton_kernels.batched_gemm_afp4wfp4_pre_quant import ( _batched_gemm_afp4_wfp4_pre_quant_reduce_kernel, @@ -33,19 +32,22 @@ def batched_gemm_afp4wfp4_pre_quant( config: Optional[dict] = None, ): """ - Computes the matmul Y = X x W - W is an e2m1 fp4 tensor and w_scales is an e8m0 tensor. 
- Every 32 elements in the K dimension share one e8m0 scale. - X gets quantized to the microscale fp4 (mxfp4) format before the GEMM. - - Key parameters: - - X: Matrix X with shape (B, M, K). - - W: Matrix W with shape (B, N, K). - - X_scales: Matrix with shape (B, M, K // 32) - - W_scales: Matrix with shape (B, N, K // 32) + Computes batched FP4 matrix multiplication Y[i] = X[i] @ W[i]^T with on-the-fly activation quantization. + X is quantized to MXFP4 during computation; W is pre-quantized FP4. + + Args: + x (torch.Tensor): Higher precision input batch with shape (B, M, K) (BF16 or FP16). + Quantized to MXFP4 on-the-fly during GEMM. + w (torch.Tensor): FP4 E2M1 weight batch with shape (B, N, K), used as W[i]^T in the product. + w_scales (torch.Tensor): E8M0 per-group scale for w with shape (B, N, K//32). + One scale per 32 elements in K dimension. + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (B, M, N). + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE). Returns: - - Y: The output matrix with shape (M, N). + torch.Tensor: Output batch with shape (B, M, N). 
""" _LOGGER.info( f"BATCHED_GEMM_AFP4WFP_PREQUANT: x={tuple(x.shape)} w={tuple(w.shape)} w_scale={tuple(w.shape)}" @@ -58,7 +60,6 @@ def batched_gemm_afp4wfp4_pre_quant( By, _, _ = y.shape assert Bx == Bw == By Batch = Bx - w = w.transpose(1, 2) if config is None: config = _get_config(M, N, K) diff --git a/aiter/ops/triton/batched_gemm_bf16.py b/aiter/ops/triton/batched_gemm_bf16.py index 8883c5688e..6948b142ae 100644 --- a/aiter/ops/triton/batched_gemm_bf16.py +++ b/aiter/ops/triton/batched_gemm_bf16.py @@ -4,7 +4,6 @@ from typing import Optional import torch import triton -import triton.language as tl from aiter.ops.triton._triton_kernels.batched_gemm_bf16 import ( _batched_gemm_bf16_kernel, _get_config, @@ -24,16 +23,20 @@ def batched_gemm_bf16( config: Optional[dict] = None, ): """ - Computes the matmul YQ[i] = XQ[i] x WQ[i]T for every i in a given batch and optionally adds a bias to each result. + Computes batched 16 bit matrix multiplication Y[i] = X[i] @ W[i]^T with optional bias. - Key parameters: - - XQ: Batch tensor XQ with shape (B, M, K). - - WQ: Batch tensor WQ with shape (B, N, K). - - Bias: Bias batch tensor with shape (B, 1, N). - - YQ: Output Matrix Y with shape (B, M, N). If this is none, then it's created by this API and returned as output + Args: + XQ (torch.Tensor): Input batch with shape (B, M, K) (BF16 or FP16). + WQ (torch.Tensor): Weight batch with shape (B, N, K), internally transposed. + bias (Optional[torch.Tensor]): Bias batch with shape (B, 1, N). + dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). + splitK (Optional[int]): Not supported. Must be None. + YQ (Optional[torch.Tensor]): Pre-allocated output tensor with shape (B, M, N). + config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M). Returns: - - YQ: The output batch tensor with shape (B, M, N). + torch.Tensor: Output batch with shape (B, M, N). 
""" _LOGGER.info(f"BATCHED_GEMM_BF16: x={tuple(XQ.shape)} w={tuple(WQ.shape)}")