@@ -194,6 +194,7 @@ def do_compile(func, verbose):
194194 if raise_error :
195195 raise e
196196 return e
197+
197198 fut2idx = {}
198199 futures = []
199200 result = [None for _ in range (len (funcs ))]
@@ -393,7 +394,8 @@ def run(self):
393394 records = []
394395 best_latency , best , best_args = None , None , None
395396 num_errors = 0
396- progress_bar = get_tqdm (zip (self .configs , self .kernels ), total = len (self .configs ), desc = "Benchmarking" )
397+ progress_bar = get_tqdm (
398+ zip (self .configs , self .kernels ), total = len (self .configs ), desc = "Benchmarking" )
397399 for cfg , ker in progress_bar :
398400 const_args , dyn_args = self .arg_parser (* cfg .args , ** cfg .kwargs ) # type: ignore
399401 record = {k : v for k , v in zip (self .arg_parser .const_arg_names , const_args )}
@@ -505,30 +507,31 @@ def tune_configs(self,
505507 result = tuner .run ()
506508 return result
507509
508- def tune (self ,
509- * args : _P .args ,
510- ** kws : _P .kwargs ) -> AutoTuneResult :
510+ def tune (self , * args : _P .args , ** kws : _P .kwargs ) -> AutoTuneResult :
511511 """Tune with the given args, return the tune result
512512
513513 Args: the same as the decorated tilelang kernel
514514
515515 Returns: a object represents the tune result
516516 """
517+
517518 def get_kw_arg (k , default ):
518519 if k in kws :
519520 v = kws [k ]
520521 del kws [k ]
521522 return v
522523 else :
523524 return default
525+
524526 raise_error = get_kw_arg ('_raise_error' , True )
525527 _config = get_kw_arg ('_config' , None )
526528 max_workers = get_kw_arg ('_max_workers' , None )
527529 const_args , _ = self .arg_parser (* args , ** kws )
528530 if const_args in self .tune_cache :
529531 return self .tune_cache [const_args ]
530532 configs = self .get_tune_configs (* args , ** kws )
531- result = self .tune_configs (configs , max_workers = max_workers , _config = _config , raise_error = raise_error )
533+ result = self .tune_configs (
534+ configs , max_workers = max_workers , _config = _config , raise_error = raise_error )
532535 self .tune_cache [const_args ] = result
533536 return result
534537
0 commit comments