 import inspect
 from typing import Dict

-from ..testing import do_bench, do_bench_cudagraph
 from .jit import KernelInterface
 from .errors import OutOfResources
+from .driver import driver


 class Autotuner(KernelInterface):
@@ -24,9 +24,10 @@ def __init__(
         pre_hook=None,
         post_hook=None,
         prune_configs_by: Dict = None,
-        warmup=25,
-        rep=100,
+        warmup=None,
+        rep=None,
         use_cuda_graph=False,
+        do_bench=None,
     ):
         """
         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
@@ -88,10 +89,35 @@ def _post_hook(args, exception):
         self.base_fn = fn
         while not inspect.isfunction(self.base_fn):
             self.base_fn = self.base_fn.fn
-        self.num_warmups = warmup
-        self.num_reps = rep
-        import torch
-        self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
+
+        # If we got explicitly called via the old interface, raise a warning
+        # and proceed with the old behavior.
+        if warmup is not None or rep is not None or use_cuda_graph:
+            import warnings
+            warnings.warn("warmup, rep, and use_cuda_graph parameters are deprecated. See _ for details.",
+                          DeprecationWarning)
+            import torch
+            if use_cuda_graph and torch.cuda.is_available():
+                from ..testing import do_bench_cudagraph
+                self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
+                    kernel_call,
+                    rep=rep if rep is not None else 100,
+                    quantiles=quantiles,
+                )
+                return
+
+            import triton.testing
+            self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
+                kernel_call,
+                warmup=warmup if warmup is not None else 25,
+                rep=rep if rep is not None else 100,
+                quantiles=quantiles,
+            )
+
+        if do_bench is None:
+            self.do_bench = driver.active.get_benchmarker()
+        else:
+            self.do_bench = do_bench

     def _bench(self, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure
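With the old keywords now optional, the deprecated path in `__init__` only activates when a caller still passes them. A minimal sketch of what that looks like from the caller's side, assuming a trivial placeholder kernel and config (`noop_kernel`, `BLOCK`, and the warning-capture scaffolding are illustrative, not part of this change):

```python
import warnings

import triton
import triton.language as tl

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")

    # Old-style arguments: still honored, but now routed through the
    # deprecated branch of Autotuner.__init__ and flagged at decoration time.
    @triton.autotune(
        configs=[triton.Config({"BLOCK": 64})],
        key=["n"],
        warmup=10,
        rep=50,
    )
    @triton.jit
    def noop_kernel(x_ptr, n, BLOCK: tl.constexpr):
        pass

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```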
@@ -125,11 +151,7 @@ def kernel_call():
             self.post_hook(args, exception=None)

         try:
-            if self.use_cuda_graph:
-                import torch
-                with torch.cuda.stream(torch.cuda.Stream()):
-                    return do_bench_cudagraph(kernel_call, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
-            return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
+            return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
         except (OutOfResources, CompileTimeAssertionFailure):
             return [float("inf"), float("inf"), float("inf")]

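The rewritten `_bench` shows the contract a replacement benchmarker has to satisfy: it is called as `do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))` with a zero-argument launcher and must return one timing per requested quantile. A minimal sketch of a compatible callable, assuming a host-side timer is good enough for illustration (`simple_do_bench` and its repetition count are made up here):

```python
import time

def simple_do_bench(kernel_call, quantiles):
    # Time the zero-argument launcher a fixed number of times. A real
    # benchmarker would also warm up and synchronize the device first.
    times = []
    for _ in range(10):
        start = time.perf_counter()
        kernel_call()
        times.append((time.perf_counter() - start) * 1e3)  # milliseconds
    times.sort()
    n = len(times)
    # Nearest-rank lookup: one timing per requested quantile.
    return [times[min(int(q * (n - 1)), n - 1)] for q in quantiles]
```

Storing the callable on `self.do_bench` is what lets `_bench` drop its torch and CUDA-graph special cases.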
@@ -262,7 +284,7 @@ def __str__(self):


 def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
-             warmup=25, rep=100, use_cuda_graph=False):
+             warmup=None, rep=None, use_cuda_graph=False, do_bench=None):
     """
     Decorator for auto-tuning a :code:`triton.jit`'d function.

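The same knob is exposed through the public decorator. A hedged usage sketch in which the kernel, configs, and key are invented for illustration and only the `do_bench=` keyword reflects this change (it reuses `simple_do_bench` from the sketch above; when omitted, the autotuner falls back to `driver.active.get_benchmarker()`):

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 128}, num_warps=4),
        triton.Config({"BLOCK": 256}, num_warps=8),
    ],
    key=["n_elements"],
    do_bench=simple_do_bench,  # custom benchmarker instead of the driver default
)
@triton.jit
def scale_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)
```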