Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 246 additions & 6 deletions aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16w8_blockscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,41 @@
import triton.language as tl
from aiter.ops.triton._triton_kernels.quant.fused_fp8_quant import _fp8_quant_op
from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid
from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr
from aiter.ops.triton.utils.gemm_config_utils import get_gemm_config

import triton


# Repr helper for the plain (non-preshuffled) kernel. The listed names are the
# tuning knobs that distinguish one compiled specialization from another
# (presumably surfaced in logs/profiling output by make_kernel_repr — confirm).
_gemm_a16w8_blockscale_repr = make_kernel_repr(
    "_gemm_a16w8_blockscale_kernel",
    [
        "BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K", "GROUP_SIZE_M",
        "num_warps", "num_stages", "waves_per_eu", "matrix_instr_nonkdim",
        "cache_modifier", "NUM_KSPLIT", "SPLITK_BLOCK_SIZE",
        "EVEN_K", "GRID_MN", "PREQUANT",
    ],
)


@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0,
"GRID_MN": lambda args: triton.cdiv(args["M"], args["BLOCK_SIZE_M"])
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
@triton.jit
@triton.jit(repr=_gemm_a16w8_blockscale_repr)
def _gemm_a16w8_blockscale_kernel(
# Pointers to matrices
a_ptr,
Expand Down Expand Up @@ -54,6 +76,10 @@ def _gemm_a16w8_blockscale_kernel(
PREQUANT: tl.constexpr,
DTYPE_MAX: tl.constexpr,
DTYPE_MIN: tl.constexpr,
num_warps: tl.constexpr,
num_stages: tl.constexpr,
waves_per_eu: tl.constexpr,
matrix_instr_nonkdim: tl.constexpr,
cache_modifier: tl.constexpr,
):
"""
Expand Down Expand Up @@ -185,10 +211,224 @@ def _gemm_a16w8_blockscale_kernel(
tl.store(c_ptrs, c, mask=c_mask)


# Repr helper for the preshuffled-weight kernel variant. Same tuning-knob list
# as the plain kernel's repr, but bound to the preshuffle kernel's name.
_gemm_a16w8_blockscale_preshuffle_repr = make_kernel_repr(
    "_gemm_a16w8_blockscale_preshuffle_kernel",
    [
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
        "waves_per_eu",
        "matrix_instr_nonkdim",
        "cache_modifier",
        "NUM_KSPLIT",
        "SPLITK_BLOCK_SIZE",
        "EVEN_K",
        "GRID_MN",
        "PREQUANT",
    ],
)


@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0,
        "GRID_MN": lambda args: triton.cdiv(args["M"], args["BLOCK_SIZE_M"])
        * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
    }
)
# Fix: use this kernel's own repr helper. The original passed
# _gemm_a16w8_blockscale_repr, which bakes in the *other* kernel's name and
# left _gemm_a16w8_blockscale_preshuffle_repr unused.
@triton.jit(repr=_gemm_a16w8_blockscale_preshuffle_repr)
def _gemm_a16w8_blockscale_preshuffle_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_scale_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_ck,
    stride_cm,
    stride_cn,
    stride_bscale_n,
    stride_bscale_k,
    # Meta-parameters
    GROUP_K: tl.constexpr,
    GROUP_N: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    NUM_KSPLIT: tl.constexpr,
    SPLITK_BLOCK_SIZE: tl.constexpr,
    EVEN_K: tl.constexpr,
    GRID_MN: tl.constexpr,
    PREQUANT: tl.constexpr,
    DTYPE_MAX: tl.constexpr,
    DTYPE_MIN: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
    waves_per_eu: tl.constexpr,
    matrix_instr_nonkdim: tl.constexpr,
    cache_modifier: tl.constexpr,
):
    """
    Note: this is a Triton jitted function and not meant to be called directly;
    call the gemm_a16w8_blockscale wrapper with preshuffled weights instead.

    Computes C = A x B where B is 8-bit block-scale quantized and stored in a
    preshuffled layout: B is addressed along N in groups of 16 rows with the K
    dimension expanded 16x (see the (N // 16, K * 16) indexing below), and each
    loaded tile is un-shuffled back to (BLOCK_SIZE_K, BLOCK_SIZE_N) before the dot.

    Key parameters:
    - A: Matrix A with shape (M, K), 16-bit (optionally fp8-quantized on the
      fly when PREQUANT is set).
    - B: Preshuffled 8-bit matrix B with logical shape (K, N).
    - C: Output matrix C with shape (M, N); split-K partials are written at
      slice `pid_k` via stride_ck.
    - B_scale: Scale tensor for B with shape (*scale_k, **scale_n).

    *scale_k = (K + GROUP_K - 1) // GROUP_K
    **scale_n = (N + GROUP_N - 1) // GROUP_N
    """

    # Strides are asserted positive so the compiler can reason about pointer
    # arithmetic without sign checks.
    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_ck > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)
    tl.assume(stride_bscale_k > 0)
    tl.assume(stride_bscale_n > 0)

    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid_unified = tl.program_id(axis=0)
    pid_k = pid_unified % NUM_KSPLIT
    pid = pid_unified // NUM_KSPLIT
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    if NUM_KSPLIT == 1:
        pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M)
    else:
        # With split-K, use a simple row-major mapping; grouping is skipped.
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n

    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)
    tl.assume(pid_k >= 0)

    # Skip programs whose K slice starts past the end of the matrix.
    if (pid_k * SPLITK_BLOCK_SIZE) < K:

        # SPLITK_BLOCK_SIZE = tl.cdiv(K, NUM_KSPLIT)
        num_k_iter = tl.cdiv(SPLITK_BLOCK_SIZE, BLOCK_SIZE_K)

        # Create pointers for first block of A and B input matrices.
        # B's K axis is expanded 16x in the preshuffled layout.
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        offs_k_shuffle_arr = tl.arange(0, BLOCK_SIZE_K * 16)
        offs_k_split = pid_k * SPLITK_BLOCK_SIZE + offs_k
        offs_k_shuffle = pid_k * SPLITK_BLOCK_SIZE * 16 + offs_k_shuffle_arr

        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        offs_bn = (pid_n * (BLOCK_SIZE_N // 16) + tl.arange(0, BLOCK_SIZE_N // 16)) % (
            N // 16
        )
        a_ptrs = a_ptr + (
            offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak
        )
        b_ptrs = b_ptr + (
            offs_bn[:, None] * stride_bn + offs_k_shuffle[None, :] * stride_bk
        )

        # Create pointers for the scales (one scale per GROUP_K x GROUP_N tile).
        offs_bsn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        offs_k_scale = (pid_k * SPLITK_BLOCK_SIZE) // GROUP_K
        offs_b_scale_n = offs_bsn // GROUP_N
        b_scale_ptrs = (
            b_scale_ptr
            + offs_k_scale * stride_bscale_k
            + offs_b_scale_n * stride_bscale_n
        )
        offs_ks_step = BLOCK_SIZE_K // GROUP_K

        # int8 output accumulates in int32; everything else in fp32.
        acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)

        for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter):
            # Load the next block of A and B, generate a mask by checking the K dimension.
            # If it is out of bounds, set it to 0.
            # NOTE(review): there is no non-EVEN_K load path, so `a`/`b` would be
            # unbound when K % BLOCK_SIZE_K != 0 — presumably the wrapper
            # guarantees EVEN_K for preshuffled weights; confirm.
            if EVEN_K:
                a = tl.load(a_ptrs)
                b = tl.load(b_ptrs, cache_modifier=cache_modifier)

            # Un-shuffle the loaded B tile back into a (BLOCK_SIZE_K,
            # BLOCK_SIZE_N) operand for tl.dot.
            b = (
                b.reshape(
                    1,
                    BLOCK_SIZE_N // 16,
                    BLOCK_SIZE_K // 32,
                    2,
                    16,
                    16,
                )
                .permute(0, 1, 4, 2, 3, 5)
                .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K)
                .trans(1, 0)
            )

            b_scale = tl.load(b_scale_ptrs)

            if PREQUANT:
                # Quantize A to fp8 on the fly and fold its per-row scale back
                # into the accumulator together with B's block scale.
                a, a_scale = _fp8_quant_op(
                    a, BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_K, DTYPE_MAX, DTYPE_MIN
                )
                a = a.to(b_ptr.type.element_ty).reshape(BLOCK_SIZE_M, BLOCK_SIZE_K)
                a_scale = a_scale.reshape(BLOCK_SIZE_M)
                accumulator += (
                    tl.dot(a, b, input_precision="ieee")
                    * a_scale[:, None]
                    * b_scale[None, :]
                )
            else:
                # Dequantize B into A's dtype and apply only B's block scale.
                b = b.to(a_ptr.type.element_ty)
                accumulator += tl.dot(a, b, input_precision="ieee") * b_scale[None, :]

            # Advance the ptrs to the next K block.
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * 16 * stride_bk

            # k_cur = k * BLOCK_SIZE_K // GROUP_K
            # k_nxt = (k + 1) * BLOCK_SIZE_K // GROUP_K
            # offs_ks = k_nxt - k_cur
            b_scale_ptrs += offs_ks_step * stride_bscale_k

        c = accumulator.to(c_ptr.type.element_ty)

        # Write back the block of the output matrix C with masks; split-K
        # partials land in slice pid_k (reduced by the caller when NUM_KSPLIT > 1).
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
        c_ptrs = (
            c_ptr
            + stride_cm * offs_cm[:, None]
            + stride_cn * offs_cn[None, :]
            + pid_k * stride_ck
        )
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        tl.store(c_ptrs, c, mask=c_mask)


def _get_config(M: int, N: int, K: int, shuffle: bool = False):
    """Look up the tuned kernel config for a GEMM of shape (M, N, K).

    Key parameters:
    - M, N, K: GEMM problem dimensions.
    - shuffle: when True, select the config table tuned for preshuffled
      weights ("GEMM-A16W8_BLOCKSCALE_PRESHUFFLED"); default False preserves
      the original "GEMM-A16W8_BLOCKSCALE" lookup.
    """
    shuffle_suffix = "_PRESHUFFLED" if shuffle else ""
    config_name = f"GEMM-A16W8_BLOCKSCALE{shuffle_suffix}"

    return get_gemm_config(config_name, M, N, K)
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"any": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 1
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
{
"M_LEQ_8": {
"BLOCK_SIZE_M": 8,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 3,
"waves_per_eu": 8,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 14
},
"M_LEQ_16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 6,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 14
},
"M_LEQ_32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 8,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 8
},
"M_LEQ_64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 7
},
"M_LEQ_128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 3,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 7
},
"M_LEQ_256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 7
},
"any": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 1
}
}
Loading