From 9179efe0e2b884345de60033cd5b9a18234f38a7 Mon Sep 17 00:00:00 2001 From: senlyu163 <70838408+senlyu163@users.noreply.github.com> Date: Tue, 16 Dec 2025 10:28:55 +0800 Subject: [PATCH 1/4] Remove JIT decorator from elementwise_add function in examples --- examples/elementwise/example_elementwise_add.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py index 464312ced..01c240f92 100644 --- a/examples/elementwise/example_elementwise_add.py +++ b/examples/elementwise/example_elementwise_add.py @@ -10,7 +10,6 @@ def ref_program(x, y): return x + y -@tilelang.jit(out_idx=[-1]) def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): @T.prim_func def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)): From df34333d4eeec7021a078f583c2fcd6dc21c9c4b Mon Sep 17 00:00:00 2001 From: senlyu163 <70838408+senlyu163@users.noreply.github.com> Date: Tue, 16 Dec 2025 11:09:22 +0800 Subject: [PATCH 2/4] fix kernel compilation without autotune --- examples/elementwise/example_elementwise_add.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py index 01c240f92..f16239e69 100644 --- a/examples/elementwise/example_elementwise_add.py +++ b/examples/elementwise/example_elementwise_add.py @@ -73,7 +73,8 @@ def main(): else: # Default config config = {"block_M": 32, "block_N": 32, "threads": 128} - kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") + program = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") + kernel = tilelang.compile(program, out_idx=[-1], target="cuda") out = kernel(a, b) torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) From 05648b773af60f20d5fb165b1376c0416592cd0a Mon Sep 17 00:00:00 2001 From: senlyu163 Date: Tue, 16 Dec 2025 13:10:15 +0800 Subject: [PATCH 3/4] Refactor main function to accept parameters and update tests for autotune option --- .../elementwise/example_elementwise_add.py | 19 +++++++++---------- .../elementwise/test_example_elementwise.py | 4 ++++ 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py index f16239e69..9937c23c6 100644 --- a/examples/elementwise/example_elementwise_add.py +++ b/examples/elementwise/example_elementwise_add.py @@ -56,18 +56,11 @@ def kernel(block_M=None, block_N=None, threads=None): return autotuner.run(warmup=3, rep=20) -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--m", type=int, default=1024) - parser.add_argument("--n", type=int, default=1024) - parser.add_argument("--use_autotune", action="store_true", default=False) - args, _ = parser.parse_known_args() - M, N = args.m, args.n - +def main(M=1024, N=1024, use_autotune=False): a = torch.randn(M, N, dtype=torch.float32, device="cuda") b = torch.randn(M, N, dtype=torch.float32, device="cuda") - if args.use_autotune: + if use_autotune: result = get_best_config(M, N) kernel = result.kernel else: @@ -81,4 +74,10 @@ def main(): if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=1024) + parser.add_argument("--n", type=int, default=1024) + parser.add_argument("--use_autotune", action="store_true", default=False) + args, _ = parser.parse_known_args() + + main(args.m, args.n, args.use_autotune) diff --git a/examples/elementwise/test_example_elementwise.py b/examples/elementwise/test_example_elementwise.py index f1668f4aa..b7bed4a22 100644 --- a/examples/elementwise/test_example_elementwise.py +++ b/examples/elementwise/test_example_elementwise.py @@ -6,5 +6,9 @@ def test_example_elementwise_add(): example_elementwise_add.main() +def test_example_elementwise_add_with_autotune(): + example_elementwise_add.main(use_autotune=True) + + if __name__ == "__main__": tilelang.testing.main() From fc4d2ed8c3e7b3a19bf690efa8265138c7146a8a Mon Sep 17 00:00:00 2001 From: senlyu163 Date: Tue, 16 Dec 2025 18:12:34 +0800 Subject: [PATCH 4/4] Refactor autotune test function for morden style --- .../elementwise/example_elementwise_add.py | 45 +++++-------------- .../elementwise/test_example_elementwise.py | 2 +- 2 files changed, 13 insertions(+), 34 deletions(-) diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py index 9937c23c6..72459459b 100644 --- a/examples/elementwise/example_elementwise_add.py +++ b/examples/elementwise/example_elementwise_add.py @@ -3,13 +3,22 @@ import torch import tilelang import tilelang.language as T -from tilelang.autotuner import AutoTuner def ref_program(x, y): return x + y +def get_configs(): + block_M = [64, 128, 256] + block_N = [64, 128, 256] + threads = [64, 128, 256] + configs = list(itertools.product(block_M, block_N, threads)) + return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs] + + +@tilelang.autotune(configs=get_configs()) +@tilelang.jit(out_idx=[-1]) def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): @T.prim_func def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)): @@ -29,45 +38,16 @@ def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T. return elem_add -def get_configs(M, N): - block_M = [64, 128, 256] - block_N = [64, 128, 256] - threads = [64, 128, 256] - configs = list(itertools.product(block_M, block_N, threads)) - return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs] - - -def get_best_config(M, N): - def kernel(block_M=None, block_N=None, threads=None): - return elementwise_add(M, N, block_M, block_N, "float32", "float32", threads) - - autotuner = ( - AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N)) - .set_compile_args( - out_idx=[-1], - target="cuda", - ) - .set_profile_args( - supply_type=tilelang.TensorSupplyType.Auto, - ref_prog=ref_program, - skip_check=False, - ) - ) - return autotuner.run(warmup=3, rep=20) - - def main(M=1024, N=1024, use_autotune=False): a = torch.randn(M, N, dtype=torch.float32, device="cuda") b = torch.randn(M, N, dtype=torch.float32, device="cuda") if use_autotune: - result = get_best_config(M, N) - kernel = result.kernel + kernel = elementwise_add(M, N, in_dtype="float32", out_dtype="float32") else: # Default config config = {"block_M": 32, "block_N": 32, "threads": 128} - program = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") - kernel = tilelang.compile(program, out_idx=[-1], target="cuda") + kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") out = kernel(a, b) torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) @@ -79,5 +59,4 @@ def main(M=1024, N=1024, use_autotune=False): parser.add_argument("--n", type=int, default=1024) parser.add_argument("--use_autotune", action="store_true", default=False) args, _ = parser.parse_known_args() - main(args.m, args.n, args.use_autotune) diff --git a/examples/elementwise/test_example_elementwise.py b/examples/elementwise/test_example_elementwise.py index b7bed4a22..24f675cd6 100644 --- a/examples/elementwise/test_example_elementwise.py +++ b/examples/elementwise/test_example_elementwise.py @@ -6,7 +6,7 @@ def test_example_elementwise_add(): example_elementwise_add.main() -def test_example_elementwise_add_with_autotune(): +def test_example_elementwise_add_autotune(): example_elementwise_add.main(use_autotune=True)