4848logger = logging .getLogger ("auto_scheduler" )
4949
5050
51- def call_all_topi_funcs (mod , params , target , opt_level = 3 ):
51+ def call_all_topi_funcs (mod , params , target , error_list , opt_level = 3 ):
5252 """Call all TOPI compute to extract auto_scheduler tasks in a Relay program"""
5353 # pylint: disable=import-outside-toplevel
5454 from tvm import relay
@@ -71,7 +71,7 @@ def call_all_topi_funcs(mod, params, target, opt_level=3):
7171 try :
7272 compiler .lower (mod , target )
7373 except TVMError :
74- logger . warning ( "Got exception in task extraction: \n %s" , traceback .format_exc ())
74+ error_list . append ( f" { traceback .format_exc ()} " )
7575 finally :
7676 autotvm .GLOBAL_SCOPE .silent = old_autotvm_silent
7777
@@ -131,14 +131,21 @@ def extract_tasks(
131131 dispatch_ctx = DispatchContext .current
132132 old_verbose = dispatch_ctx .verbose
133133 dispatch_ctx .verbose = 0
134+
135+ errors = []
134136 with env :
135137 # Wrap build call in a new thread to avoid the conflict
136138 # between python's multiprocessing and tvm's thread pool
137139 build_thread = threading .Thread (
138- target = call_all_topi_funcs , args = (mod , params , target , opt_level )
140+ target = call_all_topi_funcs , args = (mod , params , target , errors , opt_level )
139141 )
140142 build_thread .start ()
141143 build_thread .join ()
144+
145+ if errors :
146+ error_strings = ["Task extraction had the following errors:" ] + errors
147+ raise TVMError ("\n " .join (error_strings ))
148+
142149 dispatch_ctx .verbose = old_verbose
143150
144151 # create search tasks
0 commit comments