Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16w16.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"NUM_KSPLIT",
"SPLITK_BLOCK_SIZE",
"EVEN_K",
"EVEN_MN",
"cache_modifier",
"activation",
"use_activation",
Expand Down Expand Up @@ -45,6 +46,8 @@
{
"EVEN_K": lambda args: (args["K"] % (args["SPLITK_BLOCK_SIZE"]) == 0)
and (args["SPLITK_BLOCK_SIZE"] % args["BLOCK_SIZE_K"] == 0),
"EVEN_MN": lambda args: (args["M"] % args["BLOCK_SIZE_M"] == 0)
and (args["N"] % args["BLOCK_SIZE_N"] == 0),
}
)
@triton.jit(
Expand Down Expand Up @@ -74,6 +77,7 @@ def _gemm_a16_w16_kernel(
NUM_KSPLIT: tl.constexpr,
SPLITK_BLOCK_SIZE: tl.constexpr,
EVEN_K: tl.constexpr,
EVEN_MN: tl.constexpr,
cache_modifier: tl.constexpr,
activation: tl.constexpr,
use_activation: tl.constexpr,
Expand Down Expand Up @@ -117,8 +121,12 @@ def _gemm_a16_w16_kernel(
# Create pointers for first block of A and B input matrices
offs_k = tl.arange(0, BLOCK_SIZE_K)
offs_k_split = split_k_start + offs_k
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
if EVEN_MN:
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
else:
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N

a_ptrs = a_ptr + (
offs_am[:, None] * stride_am + offs_k_split[None, :] * stride_ak
Expand Down Expand Up @@ -177,8 +185,11 @@ def _gemm_a16_w16_kernel(
+ stride_cn * offs_cn[None, :]
+ pid_k * stride_ck
)
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
if EVEN_MN:
tl.store(c_ptrs, c)
else:
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)


@triton.jit(repr=_gemm_a16w16_reduce_repr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"BLOCK_SIZE_K",
"GROUP_SIZE_M",
"EVEN_K",
"EVEN_MN",
"cache_modifier",
],
)
Expand All @@ -24,6 +25,8 @@
@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0,
"EVEN_MN": lambda args: (args["M"] % args["BLOCK_SIZE_M"] == 0)
and (args["N"] % args["BLOCK_SIZE_N"] == 0),
}
)
@triton.jit(
Expand Down Expand Up @@ -64,6 +67,7 @@ def _batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant_ker
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
EVEN_K: tl.constexpr,
EVEN_MN: tl.constexpr,
cache_modifier: tl.constexpr,
):
"""
Expand Down Expand Up @@ -136,8 +140,12 @@ def _batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant_ker
tl.assume(batch_id >= 0)

offs_k = tl.arange(0, BLOCK_SIZE_K)
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
if EVEN_MN:
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
else:
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
a_ptrs = a_ptr + (
batch_id * stride_ab
+ offs_am[:, None] * stride_am
Expand Down Expand Up @@ -175,7 +183,10 @@ def _batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant_ker
accumulator *= b_scale

if HAS_BIAS:
offs_bias = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
if EVEN_MN:
offs_bias = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
else:
offs_bias = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
bias = tl.load(bias_ptr + batch_id * stride_biasb + offs_bias)
accumulator = accumulator.to(bias_ptr.type.element_ty) + bias[None, :]

Expand All @@ -189,9 +200,11 @@ def _batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant_ker
+ stride_cm * offs_cm[:, None]
+ stride_cn * offs_cn[None, :]
)
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

tl.store(c_ptrs, c, mask=c_mask)
if EVEN_MN:
tl.store(c_ptrs, c)
else:
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)


def _get_config(
Expand Down
Loading