2929from tvm .tir import PrimFunc
3030from tvm .relax .expr import Function as RelaxFunc
3131from tvm .relax .utils import tir_partitioner
32- from tvm .relax .ty import DynTensorType
3332
3433from . import _ffi_api
3534from .database import Database
@@ -248,14 +247,7 @@ def extract_task_from_relay(
248247 return list (reversed (tasks ))
249248
250249
251- def extract_task_from_relax (
252- mod : Union [IRModule , RelaxFunc ],
253- target : Target ,
254- * ,
255- opt_level : int = 3 ,
256- pass_config : Dict [str , DynTensorType ] = {},
257- disabled_pass : List [str ] = [],
258- ) -> List [ExtractedTask ]:
250+ def extract_task_from_relax (mod : Union [IRModule , RelaxFunc ], target : Target ) -> List [ExtractedTask ]:
259251 """Extract tuning tasks from a relax program.
260252
261253 Parameters
@@ -264,53 +256,24 @@ def extract_task_from_relax(
264256 The module or function to tune
265257 target : tvm.target.Target
266258 The compilation target
267- opt_level : int
268- The optimization level of the compiler
269- pass_config : Dict[str, DynTensorType]
270- The pass config of the compiler
271- disabled_pass : List[str]
272- The list of disabled passes of the compiler
273259
274260 Returns
275261 -------
276262 tasks: List[ExtractedTask]
277- The tasks extracted from this network
263+ The tasks extracted from this module
278264 """
279-
280- @contextmanager
281- def _autotvm_silencer ():
282- from tvm import autotvm # pylint: disable=import-outside-toplevel
283-
284- silent = autotvm .GLOBAL_SCOPE .silent
285- autotvm .GLOBAL_SCOPE .silent = True
286- try :
287- yield
288- finally :
289- autotvm .GLOBAL_SCOPE .silent = silent
290-
291- def _thread_run (func : Callable [[], None ]) -> None :
292- import threading # pylint: disable=import-outside-toplevel
293-
294- thread = threading .Thread (target = func )
295- thread .start ()
296- thread .join ()
297-
298- env = TaskExtraction ()
299265 if isinstance (mod , RelaxFunc ):
300266 mod = IRModule .from_expr (mod )
301267 if not isinstance (target , Target ):
302268 target = Target (target )
303269
304- def _func ():
305- with env , _autotvm_silencer (), transform .PassContext (
306- config = pass_config ,
307- disabled_pass = disabled_pass ,
308- opt_level = opt_level ,
309- ):
310- tir_partitions = tir_partitioner (mod )
311- for tir_mod in tir_partitions :
312- func_name = tir_mod .get_global_vars ()[0 ].name_hint
313- MetaScheduleContext .query_inside_with_scope (func_name , tir_mod , target , [tir_mod ])
314-
315- _thread_run (_func )
316- return env .tasks
270+ tir_partitions = tir_partitioner (mod )
271+
272+ tasks = []
273+ for tir_mod in tir_partitions :
274+ task_name = tir_mod .get_global_vars ()[0 ].name_hint
275+ # The second arg to ExtractedTask is supposed to be a high-level IRModule,
276+ # passing tir_mod as a workaround.
277+ tasks .append (ExtractedTask (task_name , tir_mod , target , [tir_mod ]))
278+
279+ return tasks
0 commit comments