@@ -242,50 +242,19 @@ def extract_task_from_relay(
242242 if isinstance (param , np .ndarray ):
243243 params [name ] = nd .array (param )
244244
245- with transform .PassContext (opt_level = opt_level ):
246- with target :
247- tasks = extract_task_func (mod , target , params )
248- return tasks
249-
250- @contextmanager
251- def _autotvm_silencer ():
252- from tvm import autotvm # pylint: disable=import-outside-toplevel
253-
254- silent = autotvm .GLOBAL_SCOPE .silent
255- autotvm .GLOBAL_SCOPE .silent = True
256- try :
257- yield
258- finally :
259- autotvm .GLOBAL_SCOPE .silent = silent
260-
261- def _thread_run (func : Callable [[], None ]) -> None :
262- import threading # pylint: disable=import-outside-toplevel
263-
264- thread = threading .Thread (target = func )
265- thread .start ()
266- thread .join ()
267-
268245 if disabled_pass is None :
269246 disabled_pass = []
270- if pass_config is None :
271- pass_config = {"relay.backend.use_meta_schedule" : True }
272247
273- env = TaskExtraction ()
274248 if isinstance (mod , RelayFunc ):
275249 mod = IRModule .from_expr (mod )
276250 if not isinstance (target , Target ):
277251 target = Target (target )
278252
279- def _func ():
280- with env , _autotvm_silencer (), transform .PassContext (
281- config = pass_config ,
282- disabled_pass = disabled_pass ,
283- opt_level = opt_level ,
284- ):
285- compiler = vm .VMCompiler ()
286- if params :
287- compiler .set_params (params )
288- compiler .lower (mod , target )
289-
290- _thread_run (_func )
291- return env .tasks
253+ with transform .PassContext (
254+ opt_level = opt_level ,
255+ config = pass_config ,
256+ disabled_pass = disabled_pass ,
257+ ):
258+ with target :
259+ tasks = extract_task_func (mod , target , params )
260+ return tasks
0 commit comments