Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions aiter/ops/triton/_triton_kernels/batched_gemm_a8w8.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,23 @@
import json
import triton
import triton.language as tl
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils._triton.kernel_repr import make_kernel_repr


# Custom repr for _batched_gemm_a8w8_kernel, passed as @triton.jit(repr=...)
# below. Lists the kernel's specialization parameters to include in the repr —
# presumably so compiled variants are distinguishable in logs/kernel names;
# make_kernel_repr is defined in ..utils._triton.kernel_repr (not visible here).
_batched_gemm_a8w8_repr = make_kernel_repr(
    "_batched_gemm_a8w8_kernel",
    [
        "HAS_BIAS",
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "EVEN_K",
        "GRID_MN",
    ],
)


@triton.heuristics(
Expand All @@ -17,7 +31,7 @@
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
@triton.jit
@triton.jit(repr=_batched_gemm_a8w8_repr)
def _batched_gemm_a8w8_kernel(
# Pointers to matrices
a_ptr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,22 @@
import triton.language as tl
from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils._triton.kernel_repr import make_kernel_repr


# Custom repr for the per-token-group prequant kernel, passed as
# @triton.jit(repr=...) below. The listed names are the specialization
# parameters to surface in the kernel's repr (includes cache_modifier,
# unlike the plain a8w8 variant); make_kernel_repr is defined in
# ..utils._triton.kernel_repr (not visible here).
_batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant_repr = make_kernel_repr(
    "_batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant_kernel",
    [
        "HAS_BIAS",
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "EVEN_K",
        "cache_modifier",
        "GRID_MN",
    ],
)


@triton.heuristics(
Expand All @@ -17,7 +33,9 @@
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
@triton.jit
@triton.jit(
repr=_batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant_repr
)
def _batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant_kernel(
# Pointers to matrices
a_ptr,
Expand Down
31 changes: 29 additions & 2 deletions aiter/ops/triton/_triton_kernels/batched_gemm_afp4wfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,33 @@
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils._triton.kernel_repr import make_kernel_repr


# Custom repr for _batched_gemm_afp4_wfp4_kernel, passed as
# @triton.jit(repr=...) below. Lists the specialization parameters to include
# in the repr (split-K tuning knobs NUM_KSPLIT/SPLITK_BLOCK_SIZE included);
# make_kernel_repr is defined in ..utils._triton.kernel_repr (not visible here).
_batched_gemm_afp4_wfp4_repr = make_kernel_repr(
    "_batched_gemm_afp4_wfp4_kernel",
    [
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "NUM_KSPLIT",
        "SPLITK_BLOCK_SIZE",
        "EVEN_K",
        "GRID_MN",
        "cache_modifier",
    ],
)

# Custom repr for the split-K reduction kernel paired with
# _batched_gemm_afp4_wfp4_kernel; passed as @triton.jit(repr=...) below.
_batched_gemm_afp4_wfp4_reduce_repr = make_kernel_repr(
    "_batched_gemm_afp4_wfp4_reduce_kernel",
    [
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "ACTUAL_KSPLIT",
        "MAX_KSPLIT",
    ],
)


@triton.heuristics(
Expand All @@ -20,7 +47,7 @@
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
@triton.jit
@triton.jit(repr=_batched_gemm_afp4_wfp4_repr)
def _batched_gemm_afp4_wfp4_kernel(
a_ptr,
b_ptr,
Expand Down Expand Up @@ -210,7 +237,7 @@ def _batched_gemm_afp4_wfp4_kernel(
tl.store(c_ptrs, c, mask=c_mask)


@triton.jit
@triton.jit(repr=_batched_gemm_afp4_wfp4_reduce_repr)
def _batched_gemm_afp4_wfp4_reduce_kernel(
c_in_ptr,
c_out_ptr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,34 @@
from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from .quant import _mxfp4_quant_op
from ..utils._triton.kernel_repr import make_kernel_repr


# Custom repr for _batched_gemm_afp4_wfp4_pre_quant_kernel, passed as
# @triton.jit(repr=...) below. Same specialization-parameter list as the
# non-pre-quant afp4/wfp4 kernel; make_kernel_repr is defined in
# ..utils._triton.kernel_repr (not visible here).
_batched_gemm_afp4_wfp4_pre_quant_repr = make_kernel_repr(
    "_batched_gemm_afp4_wfp4_pre_quant_kernel",
    [
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "NUM_KSPLIT",
        "SPLITK_BLOCK_SIZE",
        "EVEN_K",
        "GRID_MN",
        "cache_modifier",
    ],
)


# Custom repr for the split-K reduction kernel paired with
# _batched_gemm_afp4_wfp4_pre_quant_kernel; passed as @triton.jit(repr=...)
# below.
_batched_gemm_afp4_wfp4_pre_quant_reduce_repr = make_kernel_repr(
    "_batched_gemm_afp4_wfp4_pre_quant_reduce_kernel",
    [
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "ACTUAL_KSPLIT",
        "MAX_KSPLIT",
    ],
)


@triton.heuristics(
Expand All @@ -21,7 +49,7 @@
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
@triton.jit
@triton.jit(repr=_batched_gemm_afp4_wfp4_pre_quant_repr)
def _batched_gemm_afp4_wfp4_pre_quant_kernel(
a_ptr,
b_ptr,
Expand Down Expand Up @@ -54,7 +82,8 @@ def _batched_gemm_afp4_wfp4_pre_quant_kernel(
GRID_MN: tl.constexpr,
cache_modifier: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
"""
Kernel for computing the matmul C = A x B.
A and B inputs are in the microscale fp4 (mxfp4) format.
A_scales and B_scales are in e8m0 format.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
Expand Down Expand Up @@ -184,7 +213,7 @@ def _batched_gemm_afp4_wfp4_pre_quant_kernel(
tl.store(c_ptrs, c, mask=c_mask)


@triton.jit
@triton.jit(repr=_batched_gemm_afp4_wfp4_pre_quant_reduce_repr)
def _batched_gemm_afp4_wfp4_pre_quant_reduce_kernel(
c_in_ptr,
c_out_ptr,
Expand Down
18 changes: 16 additions & 2 deletions aiter/ops/triton/_triton_kernels/batched_gemm_bf16.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,23 @@
import json
import triton
import triton.language as tl
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils._triton.kernel_repr import make_kernel_repr


# Custom repr for _batched_gemm_bf16_kernel, passed as @triton.jit(repr=...)
# below. Parameter list mirrors the a8w8 variant (HAS_BIAS plus block/group
# tuning constants); make_kernel_repr is defined in
# ..utils._triton.kernel_repr (not visible here).
_batched_gemm_bf16_repr = make_kernel_repr(
    "_batched_gemm_bf16_kernel",
    [
        "HAS_BIAS",
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "EVEN_K",
        "GRID_MN",
    ],
)


@triton.heuristics(
Expand All @@ -17,7 +31,7 @@
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
@triton.jit
@triton.jit(repr=_batched_gemm_bf16_repr)
def _batched_gemm_bf16_kernel(
# Pointers to matrices
a_ptr,
Expand Down
32 changes: 15 additions & 17 deletions aiter/ops/triton/batched_gemm_a8w8.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
from typing import Optional
import torch
import triton
import triton.language as tl
import aiter.ops.triton.utils._triton.arch_info as arch_info
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
from aiter.ops.triton._triton_kernels.batched_gemm_a8w8 import (
_batched_gemm_a8w8_kernel,
_get_config,
Expand All @@ -28,22 +25,23 @@ def batched_gemm_a8w8(
config: Optional[dict] = None,
):
"""
Computes the matmul YQ[i] = XQ[i] x WQ[i]T and applies a conversion scale for every i in a given batch.
Optionally, adds a bias to each result.

The conversion scale for each matmul is received in the form of two 1D tensors that are multiplied to form a
2D one before being applied.

Key parameters:
- XQ: Batch tensor XQ with shape (B, M, K).
- WQ: Batch tensor WQ with shape (B, N, K).
- X_scale: First scale batch tensor with shape (B, M, 1).
- W_scale: Second scale batch tensor with shape (B, 1, N).
- Bias: Bias batch tensor with shape (B, 1, N).
- YQ: Output Matrix Y with shape (B, M, N). If this is none, then it's created by this API and returned as output
Computes batched 8 bit matrix multiplication Y[i] = X[i] @ W[i]^T with per-batch scaling.
Each batch element is independently scaled back to higher precision.

Args:
XQ (torch.Tensor): INT8 input batch with shape (B, M, K).
WQ (torch.Tensor): INT8 weight batch with shape (B, N, K), internally transposed.
x_scale (torch.Tensor): Scale for XQ with shape (B, M, 1).
w_scale (torch.Tensor): Scale for WQ with shape (B, 1, N).
bias (Optional[torch.Tensor]): Bias batch with shape (B, 1, N).
dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16).
splitK (Optional[int]): Not supported. Must be None.
YQ (Optional[torch.Tensor]): Pre-allocated output tensor with shape (B, M, N).
config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N,
BLOCK_SIZE_K, GROUP_SIZE_M).

Returns:
- YQ: The output batch tensor with shape (B, M, N).
torch.Tensor: Output batch with shape (B, M, N).
"""
_LOGGER.info(
f"BATCHED_GEMM_A8W8: x={tuple(XQ.shape)} w={tuple(WQ.shape)} x_scale={tuple(x_scale.shape)} w_scale={tuple(w_scale.shape)}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
from typing import Optional
import torch
import triton
import triton.language as tl
import aiter.ops.triton.utils._triton.arch_info as arch_info
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
from aiter.ops.triton._triton_kernels.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
_batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant_kernel,
_get_config,
Expand All @@ -27,21 +24,26 @@ def batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(
config: Optional[dict] = None,
):
"""
Computes the matmul YQ[i] = XQ[i] x WQ[i]T and applies a conversion scale for every i in a given batch.
Optionally, adds a bias to each result.
Computes batched 8 bit matrix multiplication Y[i] = X[i] @ W[i]^T with active activation quantization.
X is quantized to INT8 during computation using per-token grouped quantization.
W is pre-quantized INT8 with per-batch-element scaling.

The conversion scale for each matmul is received in the form of two 1D tensors that are multiplied to form a
2D one before being applied.

Key parameters:
- XQ: Batch tensor XQ with shape (B, M, K) if transpose_bm_in == False else (M, B, K).
- WQ: Batch tensor WQ with shape (B, N, K).
- W_scale: Second scale batch tensor with shape (1, ).
- Bias: Bias batch tensor with shape (B, 1, N).
- YQ: Output Matrix Y with shape (B, M, N). If this is none, then it's created by this API and returned as output
Args:
X (torch.Tensor): Higher precision input batch with shape (B, M, K) or (M, B, K) if transpose_bm_in=True.
Quantized to INT8 on-the-fly during GEMM.
WQ (torch.Tensor): Pre-quantized INT8 weight batch with shape (B, N, K), internally transposed.
w_scale (torch.Tensor): Per-batch scale for WQ with shape (1,).
group_size (int): Group size for per-token grouped quantization of X. Must be power of 2.
bias (Optional[torch.Tensor]): Bias batch with shape (B, 1, N).
dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16).
splitK (Optional[int]): Not supported. Must be None.
YQ (Optional[torch.Tensor]): Pre-allocated output tensor with shape (B, M, N) or (M, B, N) if transpose_bm=True.
transpose_bm (Optional[bool]): Transpose batch and M dimensions in output.
transpose_bm_in (Optional[bool]): Transpose batch and M dimensions in input.
config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M).

Returns:
- YQ: The output batch tensor with shape (B, M, N) if transpose_bm == False else (M, B, N).
torch.Tensor: Output batch with shape (B, M, N) or (M, B, N) if transpose_bm=True.
"""

# Check constraints.
Expand Down
27 changes: 14 additions & 13 deletions aiter/ops/triton/batched_gemm_afp4wfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Optional
import torch
import triton
import triton.language as tl
import aiter.ops.triton.utils._triton.arch_info as arch_info
from aiter.ops.triton._triton_kernels.batched_gemm_afp4wfp4 import (
_batched_gemm_afp4_wfp4_kernel,
Expand Down Expand Up @@ -34,20 +33,22 @@ def batched_gemm_afp4wfp4(
config: Optional[dict] = None,
):
"""
Computes the matmul Y = X x W
X and W are e2m1 fp4 tensors.
x_scales and w_scales are e8m0 tensors.
Every 32 elements in the K dimension share one e8m0 scale.


Key parameters:
- X: Matrix X with shape (B, M, K).
- W: Matrix W with shape (B, N, K).
- X_scales: Matrix with shape (B, M, K // 32)
- W_scales: Matrix with shape (B, N, K // 32)
Computes batched FP4 matrix multiplication Y[i] = X[i] @ W[i]^T with FP4 activations and weights.

Args:
x (torch.Tensor): FP4 E2M1 input batch with shape (B, M, K).
w (torch.Tensor): FP4 E2M1 weight batch with shape (B, N, K), internally transposed.
x_scales (torch.Tensor): E8M0 per-group scale for x with shape (B, M, K//32).
One scale per 32 elements in K dimension.
w_scales (torch.Tensor): E8M0 per-group scale for w with shape (B, N, K//32).
One scale per 32 elements in K dimension.
dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16).
y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (B, M, N).
config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N,
BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE).

Returns:
- Y: The output matrix with shape (B, M, N).
torch.Tensor: Output batch with shape (B, M, N).
"""
_LOGGER.info(
f"BATCHED_GEMM_AFP4WFP4: x={tuple(x.shape)} w={tuple(w.shape)} x_scale={tuple(x.shape)} w_scale={tuple(w.shape)}"
Expand Down
Loading
Loading