Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 18 additions & 40 deletions examples/elementwise/example_elementwise_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,21 @@
import torch
import tilelang
import tilelang.language as T
from tilelang.autotuner import AutoTuner


def ref_program(x, y):
    """Reference implementation: return ``x + y``."""
    return x + y


def get_configs():
    """Build the autotuning search space as a list of keyword dicts.

    Every combination of the candidate ``block_M``, ``block_N`` and
    ``threads`` values is produced, in ``itertools.product`` order (the last
    factor, ``threads``, varies fastest).

    Returns:
        A list of 27 dicts, each with the keys ``"block_M"``, ``"block_N"``
        and ``"threads"``.
    """
    block_M = [64, 128, 256]
    block_N = [64, 128, 256]
    threads = [64, 128, 256]
    # Iterate the product lazily; no intermediate list of tuples is needed.
    return [
        {"block_M": bm, "block_N": bn, "threads": th}
        for bm, bn, th in itertools.product(block_M, block_N, threads)
    ]


@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
@T.prim_func
Expand All @@ -30,47 +38,12 @@ def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.
return elem_add


def get_configs(M, N):
    """Build the autotuning search space as a list of keyword dicts.

    Args:
        M: Accepted for interface compatibility; not used by this function.
        N: Accepted for interface compatibility; not used by this function.

    Returns:
        A list of 27 dicts, one per combination of the candidate
        ``"block_M"``, ``"block_N"`` and ``"threads"`` values, in
        ``itertools.product`` order (``threads`` varies fastest).
    """
    block_M = [64, 128, 256]
    block_N = [64, 128, 256]
    threads = [64, 128, 256]
    # Iterate the product lazily; no intermediate list of tuples is needed.
    return [
        {"block_M": bm, "block_N": bn, "threads": th}
        for bm, bn, th in itertools.product(block_M, block_N, threads)
    ]


def get_best_config(M, N):
    """Autotune ``elementwise_add`` for the given ``M`` and ``N``.

    Args:
        M: First size argument forwarded to ``elementwise_add``.
        N: Second size argument forwarded to ``elementwise_add``.

    Returns:
        Whatever ``AutoTuner.run`` returns for the configured search.
    """

    def kernel(block_M=None, block_N=None, threads=None):
        # Expose only the tunable parameters; sizes and dtypes are fixed here.
        return elementwise_add(M, N, block_M, block_N, "float32", "float32", threads)

    autotuner = (
        AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N))
        .set_compile_args(
            out_idx=[-1],
            target="cuda",
        )
        .set_profile_args(
            supply_type=tilelang.TensorSupplyType.Auto,
            # Candidates are checked against the reference (skip_check=False).
            ref_prog=ref_program,
            skip_check=False,
        )
    )
    return autotuner.run(warmup=3, rep=20)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=1024)
parser.add_argument("--n", type=int, default=1024)
parser.add_argument("--use_autotune", action="store_true", default=False)
args, _ = parser.parse_known_args()
M, N = args.m, args.n

def main(M=1024, N=1024, use_autotune=False):
a = torch.randn(M, N, dtype=torch.float32, device="cuda")
b = torch.randn(M, N, dtype=torch.float32, device="cuda")

if args.use_autotune:
result = get_best_config(M, N)
kernel = result.kernel
if use_autotune:
kernel = elementwise_add(M, N, in_dtype="float32", out_dtype="float32")
else:
# Default config
config = {"block_M": 32, "block_N": 32, "threads": 128}
Expand All @@ -81,4 +54,9 @@ def main():


if __name__ == "__main__":
    # Parse command-line overrides here so that main() itself stays callable
    # with plain arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument("--m", type=int, default=1024)
    parser.add_argument("--n", type=int, default=1024)
    parser.add_argument("--use_autotune", action="store_true", default=False)
    # parse_known_args() ignores unrecognised arguments instead of exiting.
    args, _ = parser.parse_known_args()
    # Single entry call; a second bare main() would rerun the example with
    # defaults and ignore the parsed arguments.
    main(args.m, args.n, args.use_autotune)
4 changes: 4 additions & 0 deletions examples/elementwise/test_example_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,9 @@ def test_example_elementwise_add():
example_elementwise_add.main()


def test_example_elementwise_add_autotune():
    """Run the elementwise-add example with ``use_autotune=True``."""
    example_elementwise_add.main(use_autotune=True)


if __name__ == "__main__":
    # Allow this test module to be executed directly as a script.
    tilelang.testing.main()
Loading