Commit d6f96e6

[AutoTune] Support `with set_autotune_inputs` to set auto-tuning input tensors (#632)
* [Refactor] Simplify and modularize autotuner implementation
  - Removed unused imports and extensive code sections from the autotuner module to enhance readability and maintainability.
  - Modularized the code by introducing new imports for autotuning and capturing functionalities, streamlining the overall structure.
  - Improved logging setup and removed redundant timeout handling functions, focusing on core autotuning logic.
  - Updated the AutoTuner class to better utilize the new modular structure, ensuring efficient performance during auto-tuning processes.

* [Refactor] Clean up and enhance capture and tuner modules
  - Improved code readability by removing unnecessary blank lines and organizing imports in `capture.py` and `tuner.py`.
  - Enhanced logging in the `AutoTuner` class to provide clearer warnings regarding the usage of `supply_prog` in the context of auto-tuning.
  - Streamlined the `CaptureStack` class for better thread-local context management.

* lint fix

* [Refactor] Simplify configuration and autotuning logic in the blocksparse GEMM example
  - Updated the `get_configs` function to reduce the number of configurations, enhancing tuning speed and clarity.
  - Removed the `get_best_config` function, integrating its logic directly into the `blocksparse_matmul` function with the `@autotune` decorator for streamlined autotuning.
  - Adjusted the main function to directly utilize the autotuned kernel, simplifying the overall structure and improving readability.
  - Deleted the obsolete test file for the autotuning decorator, cleaning up the codebase.

* [Refactor] Improve code formatting and readability in the autotune test file
  - Reformatted the `matmul` and `get_configs` functions for better readability by adjusting line breaks and indentation.
  - Fixed a typo in the `enable_rasteration` parameter name to ensure consistency.
  - Cleaned up unnecessary blank lines to enhance overall code clarity.

* Update example_blocksparse_gemm.py

* Update capture.py
1 parent ce6c23d commit d6f96e6
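The commit message above mentions a `CaptureStack` for thread-local context management; that is the mechanism behind the `with set_autotune_inputs(...)` pattern named in the title, which hands concrete input tensors to the tuner instead of a `supply_prog` callback. Below is a generic, simplified sketch of such a pattern in plain Python; it is not tilelang's actual implementation, and the names merely mirror those in the commit message.

import threading
from contextlib import contextmanager


class CaptureStack:
    """Illustrative thread-local stack of captured autotune inputs (sketch only)."""

    def __init__(self):
        self._local = threading.local()

    def _stack(self):
        # Each thread lazily gets its own list, so concurrent tuning runs
        # cannot see each other's captured tensors.
        if not hasattr(self._local, "stack"):
            self._local.stack = []
        return self._local.stack

    def push(self, tensors):
        self._stack().append(tensors)

    def pop(self):
        return self._stack().pop()

    def top(self):
        stack = self._stack()
        return stack[-1] if stack else None


_CAPTURE_STACK = CaptureStack()


@contextmanager
def set_autotune_inputs(*tensors):
    """Make the given tensors visible to any autotune run inside the block."""
    _CAPTURE_STACK.push(tensors)
    try:
        yield
    finally:
        _CAPTURE_STACK.pop()

Inside the library, the tuner would consult the top of such a stack before falling back to a `supply_prog` or `supply_type`; the real API surface may differ.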

File tree

7 files changed (+1079, -1110 lines)

examples/blocksparse_gemm/example_blocksparse_gemm.py

Lines changed: 7 additions & 62 deletions

@@ -5,7 +5,6 @@
 import itertools
 import tilelang
 import tilelang.language as T
-from tilelang.autotuner import AutoTuner
 from tilelang.engine.param import KernelParam
 from tilelang.utils.tensor import get_tensor_supply, TensorSupplyType
 import torch
@@ -37,7 +36,7 @@
 print(f"Using Autotuner: {use_autotune}\n")


-def get_configs(M, N, K):
+def get_configs():
     block_M = [64, 128, 256]
     block_N = [64, 128, 256]
     block_K = [32, 64]
@@ -93,55 +92,7 @@ def supply_program(params: List[KernelParam]):
     return input_tensors


-def get_best_config(M, N, K):
-
-    # Define the kernel function to be tuned.
-    # Parameters like block_M, block_N, etc., are tuned by the AutoTuner.
-    def kernel(block_M=None,
-               block_N=None,
-               block_K=None,
-               num_stages=None,
-               thread_num=None,
-               enable_rasteration=None):
-        return blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num,
-                                  enable_rasteration)
-
-    autotuner = AutoTuner.from_kernel(
-        kernel=kernel, configs=get_configs(M, N, K)
-    ).set_compile_args(
-        out_idx=[-1],  # Index of the output tensor
-        target="auto",  # Automatically detect target
-    ).set_profile_args(
-        # supply_type should not set here because we provide a custom supply
-        # function `supply_prog` and `supply_type` will be ignored.
-
-        # supply_prog: Provide the custom function to generate input tensors
-        # (A, B, BlockMask) for the kernel, allowing controlling sparsity via
-        # BlockMask generation.
-        supply_prog=supply_program,
-
-        # ref_prog: Using dense matmul (A @ B) as a placeholder reference.
-        # The 'correct' block-sparse reference (`ref_program` above) requires
-        # block_M, block_N, block_K parameters. However, these parameters are
-        # part of the configuration being *tuned* by the AutoTuner and cannot
-        # be fixed inputs to a static `ref_prog` function signature.
-        # This dense matmul serves only as a performance baseline.
-        ref_prog=lambda A, B, BlockMask: A @ B,
-
-        # skip_check: Set to True because the provided `ref_prog` does not
-        # compute the correct result for the block-sparse kernel.
-        skip_check=True,
-
-        # cache_input_tensors: Set to False because the shape of the BlockMask tensor
-        # (dependent on block_M, block_N, block_K being tuned) changes between
-        # different configurations. Reusing cached tensors from a previous
-        # configuration would lead to shape mismatches.
-        cache_input_tensors=False,
-    )
-    # Run the tuning process
-    return autotuner.run(warmup=3, rep=20)
-
-
+@tilelang.autotune(configs=get_configs(),)
 @tilelang.jit(out_idx=[-1])
 def blocksparse_matmul(M,
                        N,
@@ -195,22 +146,16 @@ def main():
         # Run the autotuner to find the best kernel configuration and performance
         # get_best_config is expected to return an object containing the compiled kernel,
         # the best configuration found, latency, and reference latency.
-        result = get_best_config(M, N, K)
+        kernel = blocksparse_matmul(M, N, K)

-        # Extract results from the autotuner run
-        kernel = result.kernel
-        best_config = result.config
-        block_M = best_config[0]
-        block_N = best_config[1]
-        block_K = best_config[2]
-        best_latency = result.latency
-        ref_latency = result.ref_latency
+        best_config = kernel.config
+        best_latency = kernel.latency
+        block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config[
+            "block_K"]

         print(f"Best Config: {best_config}")
-        print(f"Block Dimensions (BM, BN, BK): ({block_M}, {block_N}, {block_K})")
         print(f"Sparsity Ratio: {sparsity}")
         print(f"Best Kernel Latency: {best_latency:.6f} ms")
-        print(f"Reference Latency: {ref_latency:.6f} ms")
     else:
         kernel = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K,
                                     DEFAULT_NUM_STAGES, DEFAULT_THREAD_NUM,
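With the decorator-based flow above, the tuned path of the example reduces to the call-side sketch below. The `@tilelang.autotune`/`@tilelang.jit` stacking and the `kernel.config` and `kernel.latency` attributes come directly from the diff; the `with set_autotune_inputs(...)` form comes from the commit title, and its import path and exact signature are assumptions.

# Call-side sketch; `blocksparse_matmul`, M, N, K and the input tensors are the
# ones defined in example_blocksparse_gemm.py.
kernel = blocksparse_matmul(M, N, K)   # first call runs the autotune sweep
best_config = kernel.config            # dict keyed by tuned parameter names
best_latency = kernel.latency          # best measured latency in ms
block_M = best_config["block_M"]

# Feature named in the commit title (import path assumed): capture fixed input
# tensors for the whole tuning run instead of relying on generated inputs.
# from tilelang.autotuner import set_autotune_inputs
# with set_autotune_inputs(A, B, block_mask):
#     kernel = blocksparse_matmul(M, N, K)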

examples/convolution/example_convolution_autotune.py

Lines changed: 4 additions & 5 deletions

@@ -3,8 +3,7 @@
 import torch
 import argparse
 import itertools
-import tilelang as tl
-from tilelang.autotuner import *
+import tilelang
 import tilelang.language as T
 from tilelang.autotuner import AutoTuner
 from tilelang.carver.template import ConvTemplate
@@ -167,7 +166,7 @@ def main(
         out_idx=[2],
         target="auto",
     ).set_profile_args(
-        supply_type=tl.TensorSupplyType.Integer,
+        supply_type=tilelang.TensorSupplyType.Integer,
         ref_prog=ref_prog,
         skip_check=False,
     )
@@ -301,9 +300,9 @@ def main(n: int = 128,
         kernel = result.kernel
     else:
         config = get_heuristic_config()
-        kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_idx=[2])
+        kernel = tilelang.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_idx=[2])

-    profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto)
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
     tilelang_latency = profiler.do_bench()
     ref_latency = profiler.do_bench(ref_prog)
     profiler.assert_allclose(ref_prog, atol=1e-2, rtol=1e-2)
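For reference, the non-autotuned branch of this example now reads as the condensed sketch below; `convolution`, `ref_prog`, `get_heuristic_config` and the shape arguments are the ones defined in example_convolution_autotune.py, and only the `tl` alias has been replaced with the full `tilelang` name.

import tilelang

# Compile the heuristic configuration and benchmark it against the reference.
config = get_heuristic_config()
kernel = tilelang.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_idx=[2])

profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
tilelang_latency = profiler.do_bench()        # latency of the tilelang kernel
ref_latency = profiler.do_bench(ref_prog)     # latency of the reference program
profiler.assert_allclose(ref_prog, atol=1e-2, rtol=1e-2)  # numerical check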

testing/python/autotune/test_tilelang_autotune_decorator.py

Lines changed: 0 additions & 265 deletions
This file was deleted.
