Skip to content
Merged
42 changes: 12 additions & 30 deletions benchmark/matmul/benchmark_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,36 +89,18 @@ def get_configs(M, N, K, with_roller=False):
for config in configs:
print(config)
else:

block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
policy = [T.GemmWarpPolicy.Square]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
policy,
enable_rasterization,
))

configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"policy": c[5],
"enable_rasteration": c[6], # keep param name for backward-compat
} for c in _configs
]
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
block_K=[32, 64],
num_stages=[0, 1, 2, 3],
thread_num=[128, 256],
policy=[T.GemmWarpPolicy.Square],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return configs


Expand Down
58 changes: 23 additions & 35 deletions benchmark/matmul/benchmark_matmul_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,51 +221,39 @@ def get_configs(M, N, K, with_roller=False):
print(config)
else:

block_rows_warps = [1, 2, 4]
block_col_warps = [1, 2, 4]
warp_row_tiles = [16, 32, 64, 128]
warp_col_tiles = [16, 32, 64, 128]
chunk = [32, 64, 128, 256]
stage = [0, 2]
enable_rasteration = [True, False]
_configs = list(
itertools.product(block_rows_warps, block_col_warps, warp_row_tiles, warp_col_tiles,
chunk, stage, enable_rasteration))
configs = [{
"block_row_warps": c[0],
"block_col_warps": c[1],
"warp_row_tiles": c[2],
"warp_col_tiles": c[3],
"chunk": c[4],
"stage": c[5],
"enable_rasteration": c[6],
} for c in _configs]
iter_params = dict(
block_row_warps=[1, 2, 4],
block_col_warps=[1, 2, 4],
warp_row_tiles=[16, 32, 64, 128],
warp_col_tiles=[16, 32, 64, 128],
chunk=[32, 64, 128, 256],
stage=[0, 2],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]

return configs


def matmul(M,
N,
K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
with_roller=False):
def matmul(
M,
N,
K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
with_roller=False,
):
"""Create an autotuned tensor core matrix multiplication kernel."""

@autotune(
configs=get_configs(M, N, K, with_roller),
keys=[
"block_row_warps",
"block_col_warps",
"warp_row_tiles",
"warp_col_tiles",
"chunk",
"enable_rasteration",
"stage",
],
warmup=3,
rep=5,
ref_prog=ref_program,
skip_check=True,
)
@tl.jit(out_idx=[2],)
def kernel(
Expand Down
41 changes: 12 additions & 29 deletions benchmark/matmul_fp8/benchmark_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,36 +90,19 @@ def get_configs(M, N, K, with_roller=False):
for config in configs:
print(config)
else:
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
block_K=[64, 128],
num_stages=[0, 1, 2, 3],
thread_num=[128, 256],
policy=[T.GemmWarpPolicy.Square],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]

block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [64, 128]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
policy = [T.GemmWarpPolicy.Square]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
policy,
enable_rasterization,
))

configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"policy": c[5],
"enable_rasteration": c[6], # keep param name for backward-compat
} for c in _configs
]
return configs


Expand Down
4 changes: 0 additions & 4 deletions docs/deeplearning_operators/gemv.md
Original file line number Diff line number Diff line change
Expand Up @@ -337,10 +337,6 @@ def get_best_config(N, K):

@autotune(
configs=get_configs(),
keys=[
"BLOCK_N",
"reduce_threads",
],
warmup=3,
rep=20,
)
Expand Down
Loading
Loading