Skip to content

Commit 1c6f82b

Browse files
masahiyongwww
authored andcommitted
Clean up task extraction (tlc-pack#92)
* Clean up taske extraction * black
1 parent b8a7284 commit 1c6f82b

File tree

1 file changed

+12
-49
lines changed

1 file changed

+12
-49
lines changed

python/tvm/meta_schedule/integration.py

Lines changed: 12 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from tvm.tir import PrimFunc
3030
from tvm.relax.expr import Function as RelaxFunc
3131
from tvm.relax.utils import tir_partitioner
32-
from tvm.relax.ty import DynTensorType
3332

3433
from . import _ffi_api
3534
from .database import Database
@@ -248,14 +247,7 @@ def extract_task_from_relay(
248247
return list(reversed(tasks))
249248

250249

251-
def extract_task_from_relax(
252-
mod: Union[IRModule, RelaxFunc],
253-
target: Target,
254-
*,
255-
opt_level: int = 3,
256-
pass_config: Dict[str, DynTensorType] = {},
257-
disabled_pass: List[str] = [],
258-
) -> List[ExtractedTask]:
250+
def extract_task_from_relax(mod: Union[IRModule, RelaxFunc], target: Target) -> List[ExtractedTask]:
259251
"""Extract tuning tasks from a relax program.
260252
261253
Parameters
@@ -264,53 +256,24 @@ def extract_task_from_relax(
264256
The module or function to tune
265257
target : tvm.target.Target
266258
The compilation target
267-
opt_level : int
268-
The optimization level of the compiler
269-
pass_config : Dict[str, DynTensorType]
270-
The pass config of the compiler
271-
disabled_pass : List[str]
272-
The list of disabled passes of the compiler
273259
274260
Returns
275261
-------
276262
tasks: List[ExtractedTask]
277-
The tasks extracted from this network
263+
The tasks extracted from this module
278264
"""
279-
280-
@contextmanager
281-
def _autotvm_silencer():
282-
from tvm import autotvm # pylint: disable=import-outside-toplevel
283-
284-
silent = autotvm.GLOBAL_SCOPE.silent
285-
autotvm.GLOBAL_SCOPE.silent = True
286-
try:
287-
yield
288-
finally:
289-
autotvm.GLOBAL_SCOPE.silent = silent
290-
291-
def _thread_run(func: Callable[[], None]) -> None:
292-
import threading # pylint: disable=import-outside-toplevel
293-
294-
thread = threading.Thread(target=func)
295-
thread.start()
296-
thread.join()
297-
298-
env = TaskExtraction()
299265
if isinstance(mod, RelaxFunc):
300266
mod = IRModule.from_expr(mod)
301267
if not isinstance(target, Target):
302268
target = Target(target)
303269

304-
def _func():
305-
with env, _autotvm_silencer(), transform.PassContext(
306-
config=pass_config,
307-
disabled_pass=disabled_pass,
308-
opt_level=opt_level,
309-
):
310-
tir_partitions = tir_partitioner(mod)
311-
for tir_mod in tir_partitions:
312-
func_name = tir_mod.get_global_vars()[0].name_hint
313-
MetaScheduleContext.query_inside_with_scope(func_name, tir_mod, target, [tir_mod])
314-
315-
_thread_run(_func)
316-
return env.tasks
270+
tir_partitions = tir_partitioner(mod)
271+
272+
tasks = []
273+
for tir_mod in tir_partitions:
274+
task_name = tir_mod.get_global_vars()[0].name_hint
275+
# The second arg to ExtractedTask is supposed to be a high-level IRModule,
276+
# passing tir_mod as a workaround.
277+
tasks.append(ExtractedTask(task_name, tir_mod, target, [tir_mod]))
278+
279+
return tasks

0 commit comments

Comments
 (0)