Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 180 additions & 5 deletions aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16wfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,6 @@ def _gemm_a16wfp4_kernel(

for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter):
b_scales = tl.load(b_scale_ptrs)
# a_scales = tl.full((BLOCK_SIZE_M, BLOCK_SIZE_K//SCALE_GROUP_SIZE), 127, dtype=tl.uint8)
# b_scales = tl.full((BLOCK_SIZE_N, BLOCK_SIZE_K//SCALE_GROUP_SIZE), 127, dtype=tl.uint8)
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
if EVEN_K:
Expand Down Expand Up @@ -179,8 +177,6 @@ def _gemm_a16wfp4_kernel(
+ pid_k * stride_ck
)
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
# if pid == 0:
# tl.device_print("c", c)
if ATOMIC_ADD:
tl.atomic_add(c_ptrs, c, mask=c_mask, sem="relaxed")
else:
Expand All @@ -204,6 +200,182 @@ def _gemm_a16wfp4_kernel(
)


@triton.heuristics(
    {
        # EVEN_K: true when every K-tile and every split-K segment divides K
        # evenly, so the main loop can load A/B without K-dimension masks.
        # The "// 2" factors account for fp4 packing (two values per uint8).
        "EVEN_K": lambda args: (args["K"] % (args["BLOCK_SIZE_K"] // 2) == 0)
        and (args["SPLITK_BLOCK_SIZE"] % args["BLOCK_SIZE_K"] == 0)
        and (args["K"] % (args["SPLITK_BLOCK_SIZE"] // 2) == 0),
        # GRID_MN: number of output (M, N) tiles; the launch grid is
        # GRID_MN * NUM_KSPLIT unified program ids.
        "GRID_MN": lambda args: triton.cdiv(args["M"], args["BLOCK_SIZE_M"])
        * triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
    }
)
@triton.jit(repr=_gemm_a16wfp4_preshuffle_repr)
def _gemm_a16wfp4_preshuffle_kernel(
    a_ptr,  # A operand in its source 16-bit precision, shape (M, K) elements
    b_ptr,  # B operand, packed mxfp4 (2 values per uint8), pre-shuffled layout
    c_ptr,  # output C; split-K partials land in slices offset by stride_ck
    b_scales_ptr,  # e8m0 scales for B (one per 32 K-elements), pre-shuffled
    M,
    N,
    K,  # NOTE(review): compared against packed (uint8) offsets below, so this
    # appears to be the packed K extent (elements // 2) — confirm at call site
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_ck,  # stride between split-K output slices of C
    stride_cm,
    stride_cn,
    stride_bsn,
    stride_bsk,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    NUM_KSPLIT: tl.constexpr,
    SPLITK_BLOCK_SIZE: tl.constexpr,  # K-elements handled per split-K segment
    EVEN_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
    waves_per_eu: tl.constexpr,
    matrix_instr_nonkdim: tl.constexpr,
    GRID_MN: tl.constexpr,
    PREQUANT: tl.constexpr,  # quantize A to mxfp4 in-kernel before the dot
    cache_modifier: tl.constexpr,  # cache hint applied to B / B-scale loads
):
    """Kernel for computing the matmul C = A x B with a pre-shuffled B operand.

    B is in the microscale fp4 (mxfp4) format, two fp4 values packed per
    uint8, with its bytes pre-shuffled into 16x16 hardware-friendly tiles;
    B_scales are in e8m0 format (one scale per SCALE_GROUP_SIZE == 32
    K-elements), also pre-shuffled.  A is loaded in its source 16-bit
    precision and, when PREQUANT is set, quantized to mxfp4 on the fly via
    _mxfp4_quant_op before the scaled dot.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N).
    With NUM_KSPLIT > 1 each program writes its partial tile into a separate
    split-K slice of C (offset pid_k * stride_ck) for a later reduction.

    NOTE(review): only the EVEN_K load path is implemented below; if EVEN_K
    is False, `b` is never defined before use.  The launch side presumably
    guarantees EVEN_K (see the heuristic above) — confirm.
    """

    # Strides are assumed positive so the compiler can fold address math.
    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)
    tl.assume(stride_bsk > 0)
    tl.assume(stride_bsn > 0)

    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # The unified pid interleaves the split-K index fastest.
    pid_unified = tl.program_id(axis=0)
    pid_k = pid_unified % NUM_KSPLIT
    pid = pid_unified // NUM_KSPLIT
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    if NUM_KSPLIT == 1:
        # Grouped (swizzled) tile ordering for L2 reuse.
        pid_m, pid_n = pid_grid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M=GROUP_SIZE_M)
    else:
        # Plain row-major tile ordering when split-K is active.
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n

    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)
    tl.assume(pid_k >= 0)

    # We assume 32 elements along K share the same scale.
    SCALE_GROUP_SIZE: tl.constexpr = 32

    # Skip programs whose split-K segment starts past the end of (packed) K.
    if (pid_k * SPLITK_BLOCK_SIZE // 2) < K:

        # K-tiles per split-K segment (both sides in packed units).
        num_k_iter = tl.cdiv(SPLITK_BLOCK_SIZE // 2, BLOCK_SIZE_K // 2)

        # Create pointers for first block of A and B input matrices
        # The BLOCK sizes are of the elements and in fp4 we pack 2 per uint8 container.
        offs_k_bf16 = tl.arange(0, BLOCK_SIZE_K)
        offs_k_split_bf16 = pid_k * SPLITK_BLOCK_SIZE + offs_k_bf16
        # Wrap M so partial last tiles read valid rows; masked out at store.
        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        a_ptrs = a_ptr + (
            offs_am[:, None] * stride_am + offs_k_split_bf16[None, :] * stride_ak
        )

        # B is pre-shuffled: each "row" of the stored layout holds 16 logical
        # N-rows, so the fast axis carries (BLOCK_SIZE_K // 2) * 16 bytes.
        offs_k_shuffle_arr = tl.arange(0, (BLOCK_SIZE_K // 2) * 16)
        offs_k_shuffle = pid_k * (SPLITK_BLOCK_SIZE // 2) * 16 + offs_k_shuffle_arr
        # NOTE(review): offs_bn indexes N // 16 shuffled rows, yet wraps with
        # % N (not % (N // 16)) — looks inconsistent but is harmless while the
        # indices stay below N; confirm intended modulus.
        offs_bn = (pid_n * (BLOCK_SIZE_N // 16) + tl.arange(0, BLOCK_SIZE_N // 16)) % N
        b_ptrs = b_ptr + (
            offs_bn[:, None] * stride_bn + offs_k_shuffle[None, :] * stride_bk
        )
        # Create pointers for the first block of A and B scales
        # NOTE(review): same modulus question as offs_bn (N // 32 rows, % N).
        offs_bsn = (
            pid_n * (BLOCK_SIZE_N // 32) + tl.arange(0, (BLOCK_SIZE_N // 32))
        ) % N
        offs_ks = (pid_k * (SPLITK_BLOCK_SIZE // SCALE_GROUP_SIZE) * 32) + tl.arange(
            0, BLOCK_SIZE_K // SCALE_GROUP_SIZE * 32
        )
        # B scales are N x K even though B operand is K x N.
        b_scale_ptrs = (
            b_scales_ptr
            + offs_bsn[:, None] * stride_bsn
            + offs_ks[None, :] * stride_bsk
        )

        # fp32 accumulator for the scaled-dot partial products.
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        for k in range(pid_k * num_k_iter, (pid_k + 1) * num_k_iter):
            # Un-shuffle the pre-swizzled e8m0 scales back into a dense
            # (BLOCK_SIZE_N, BLOCK_SIZE_K // 32) tile via reshape + permute.
            b_scales = (
                tl.load(b_scale_ptrs, cache_modifier=cache_modifier)
                .reshape(
                    BLOCK_SIZE_N // 32,
                    BLOCK_SIZE_K // SCALE_GROUP_SIZE // 8,
                    4,
                    16,
                    2,
                    2,
                    1,
                )
                .permute(0, 5, 3, 1, 4, 2, 6)
                .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K // SCALE_GROUP_SIZE)
            )

            # Load the next block of A and B, generate a mask by checking the K dimension.
            # If it is out of bounds, set it to 0.
            if EVEN_K:
                a_bf16 = tl.load(a_ptrs)
                b = tl.load(b_ptrs, cache_modifier=cache_modifier)

            # Un-shuffle B's 16x16 byte tiles into a packed (K // 2, N) tile
            # for tl.dot_scaled.
            b = (
                b.reshape(
                    1,
                    BLOCK_SIZE_N // 16,
                    BLOCK_SIZE_K // 64,
                    2,
                    16,
                    16,
                )
                .permute(0, 1, 4, 2, 3, 5)
                .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K // 2)
                .trans(1, 0)
            )

            # NOTE(review): `a` / `a_scales` are only defined when PREQUANT is
            # set, yet are used unconditionally below — PREQUANT must be True
            # for this kernel to compile; confirm launch-side guarantee.
            if PREQUANT:
                # Quantize the 16-bit A tile to mxfp4 + e8m0 scales in-kernel.
                a, a_scales = _mxfp4_quant_op(a_bf16, BLOCK_SIZE_K, BLOCK_SIZE_M, 32)

            accumulator += tl.dot_scaled(a, a_scales, "e2m1", b, b_scales, "e2m1")

            # Advance the ptrs to the next K block.
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += (BLOCK_SIZE_K // 2) * 16 * stride_bk
            b_scale_ptrs += BLOCK_SIZE_K * stride_bsk

        c = accumulator.to(c_ptr.type.element_ty)

        # Write back the block of the output matrix C with masks.
        # int64 offsets guard against overflow on large C / split-K buffers.
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
        c_ptrs = (
            c_ptr
            + stride_cm * offs_cm[:, None]
            + stride_cn * offs_cn[None, :]
            + pid_k * stride_ck
        )
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        # Each split-K slice is written disjointly; the reduction over pid_k
        # happens outside this kernel (no atomics needed here).
        tl.store(c_ptrs, c, mask=c_mask)


def get_splitk(K: int, BLOCK_SIZE_K: int, NUM_KSPLIT: int):
    # Heuristic to make "EVEN_K == True" hold as often as possible.
NUM_KSPLIT_STEP = 2
Expand Down Expand Up @@ -241,6 +413,9 @@ def _get_config(
M: int,
N: int,
K: int,
shuffle: bool = False,
):
shuffle_suffix = "_PRESHUFFLED" if shuffle else ""
config_name = f"GEMM-A16WFP4{shuffle_suffix}"
# Note: Config files use K=2*K in their naming
return get_gemm_config("GEMM-A16WFP4", M, N, 2 * K)
return get_gemm_config(config_name, M, N, 2 * K)
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
{
"M_LEQ_8": {
"BLOCK_SIZE_M": 8,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 14
},
"M_LEQ_16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"cache_modifier": ".cg",
"NUM_KSPLIT": 14
},
"M_LEQ_32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 14
},
"M_LEQ_64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 14
},
"M_LEQ_128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 14
},
"M_LEQ_256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 14
},
"any": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 1
}
}
14 changes: 14 additions & 0 deletions aiter/ops/triton/configs/gemm/gfx950-GEMM-A16WFP4_PRESHUFFLED.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"any": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 512,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 1,
"waves_per_eu": 2,
"matrix_instr_nonkdim": 16,
"cache_modifier": null,
"NUM_KSPLIT": 1
}
}
Loading