Skip to content
Merged
232 changes: 91 additions & 141 deletions benchmark/matmul/benchmark_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def ref_program(A, B):
return A @ B.T


def get_configs(M, N, K, with_roller=False):
def get_configs(args, kwargs):
"""
Generate a list of configuration dictionaries that will be used for tuning.

Expand All @@ -47,6 +47,8 @@ def get_configs(M, N, K, with_roller=False):
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
M, N, K, with_roller = args[:4]

if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
Expand Down Expand Up @@ -89,40 +91,40 @@ def get_configs(M, N, K, with_roller=False):
for config in configs:
print(config)
else:

block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
policy = [T.GemmWarpPolicy.Square]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
policy,
enable_rasterization,
))

configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"policy": c[5],
"enable_rasteration": c[6], # keep param name for backward-compat
} for c in _configs
]
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
block_K=[32, 64],
num_stages=[0, 1, 2, 3],
thread_num=[128, 256],
policy=[T.GemmWarpPolicy.Square],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return configs


def matmul(M, N, K, with_roller):
@autotune(
configs=get_configs,
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
def matmul(
M,
N,
K,
with_roller,
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
Expand All @@ -149,118 +151,66 @@ def matmul(M, N, K, with_roller):
The baseline latency of the reference program (for computing speedup).
"""

# Decorate the kernel with autotune & jit, specifying:
# - Tuning config list
# - Profiling keys
# - Warmup and repetition counts for better measurement
# - A reference program for correctness verification
# - The "tvm" profiler backend
# - HIP as the compilation target (modify as needed for your hardware)

@autotune(
configs=get_configs(M, N, K, with_roller),
warmup=3,
rep=20,
ref_prog=ref_program,
)
@jit(out_idx=[2],)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
accum_dtype = "float"

@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The actual kernel to compute C = A @ B^T.

Parameters
----------
block_M : int
Block size in M dimension.
block_N : int
Block size in N dimension.
block_K : int
Block size in K dimension.
num_stages : int
Number of pipelined stages (for asynchronous load).
thread_num : int
Number of threads to use per block.
enable_rasteration : bool
Whether to enable rasterization (swizzling) optimization.
k_pack : int
K dimension packing factor to improve memory coalescing.

Returns
-------
Function
A TVM Tensor Language function (T.prim_func) that computes matmul.
The compiled TVM function for block-level matrix multiplication.

- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
accum_dtype = "float"

@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.

- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):

# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Allocate a shared memory for C sub-block of shape (block_M, block_N)
C_shared = T.alloc_shared((block_M, block_N), dtype)

# Enable (or disable) swizzling optimization
T.use_swizzle(panel_size=10, enable=enable_rasteration)

# Clear out the accumulation buffer
T.clear(C_local)

# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(A[by * block_M, k * block_K], A_shared)
# Load a sub-block of B from global memory into B_shared
T.copy(B[bx * block_N, k * block_K], B_shared)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
policy=policy,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])

return main

return kernel()
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):

# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Allocate a shared memory for C sub-block of shape (block_M, block_N)
C_shared = T.alloc_shared((block_M, block_N), dtype)

# Enable (or disable) swizzling optimization
T.use_swizzle(panel_size=10, enable=enable_rasteration)

# Clear out the accumulation buffer
T.clear(C_local)

# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(A[by * block_M, k * block_K], A_shared)
# Load a sub-block of B from global memory into B_shared
T.copy(B[bx * block_N, k * block_K], B_shared)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
policy=policy,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])

return main


if __name__ == "__main__":
Expand Down
92 changes: 41 additions & 51 deletions benchmark/matmul/benchmark_matmul_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def ref_program(A, B):
return A @ B.T


def get_configs(M, N, K, with_roller=False):
def get_configs(args, kwargs):
"""
Generate a list of configuration dictionaries that will be used for tuning.

Expand All @@ -180,6 +180,9 @@ def get_configs(M, N, K, with_roller=False):
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
M, N, K = args[:3]
with_roller = args[6]

if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
Expand Down Expand Up @@ -221,62 +224,49 @@ def get_configs(M, N, K, with_roller=False):
print(config)
else:

block_rows_warps = [1, 2, 4]
block_col_warps = [1, 2, 4]
warp_row_tiles = [16, 32, 64, 128]
warp_col_tiles = [16, 32, 64, 128]
chunk = [32, 64, 128, 256]
stage = [0, 2]
enable_rasteration = [True, False]
_configs = list(
itertools.product(block_rows_warps, block_col_warps, warp_row_tiles, warp_col_tiles,
chunk, stage, enable_rasteration))
configs = [{
"block_row_warps": c[0],
"block_col_warps": c[1],
"warp_row_tiles": c[2],
"warp_col_tiles": c[3],
"chunk": c[4],
"stage": c[5],
"enable_rasteration": c[6],
} for c in _configs]
iter_params = dict(
block_row_warps=[1, 2, 4],
block_col_warps=[1, 2, 4],
warp_row_tiles=[16, 32, 64, 128],
warp_col_tiles=[16, 32, 64, 128],
chunk=[32, 64, 128, 256],
stage=[0, 2],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]

return configs


def matmul(M,
N,
K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
with_roller=False):
@autotune(
configs=get_configs,
warmup=3,
rep=5,
ref_prog=ref_program,
skip_check=True,
)
@tl.jit(out_idx=[2],)
def matmul(
M,
N,
K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
with_roller=False,
block_row_warps=None,
block_col_warps=None,
warp_row_tiles=None,
warp_col_tiles=None,
chunk=None,
stage=None,
enable_rasteration=None,
):
"""Create an autotuned tensor core matrix multiplication kernel."""

@autotune(
configs=get_configs(M, N, K, with_roller),
keys=[
"block_row_warps",
"block_col_warps",
"warp_row_tiles",
"warp_col_tiles",
"chunk",
"enable_rasteration",
"stage",
],
warmup=3,
rep=5,
)
@tl.jit(out_idx=[2],)
def kernel(
block_row_warps=None,
block_col_warps=None,
warp_row_tiles=None,
warp_col_tiles=None,
chunk=None,
stage=None,
enable_rasteration=None,
):
def kernel():
return tl_matmul(
M,
N,
Expand Down
Loading
Loading