Skip to content

Commit ffc1fc0

Browse files
authored
[TVMC] Allow selecting a subset of tasks to be used in tvmc tune (#12525)
This adds a `--tasks` flag to the `tvmc tune` command to filter the lists of tasks to be tuned. See examples below. ## Motivation - As auto-tuning can be quite time consuming, it is often desirable to cut down the number of tuned tasks in a session. - If the tuning session was canceled halfway through, it would be a bad idea to start from scratch. Instead continue with the last untuned task - Some tasks have more impact on the model performance than others, thus we should be able to train some tasks longer than others ## Examples 1. Use `--task list` to show which tasks are available for tuning ``` $ tvmc tune toycar.tflite -o out.txt --task list Available Tasks for tuning: 0. Task(func_name=dense_nopack.x86, args=(('TENSOR', (1, 640), 'int16'), ('TENSOR', (128, 640), 'int... 1. Task(func_name=dense_pack.x86, args=(('TENSOR', (1, 640), 'int16'), ('TENSOR', (128, 640), 'int16... 2. Task(func_name=dense_nopack.x86, args=(('TENSOR', (1, 128), 'int16'), ('TENSOR', (128, 128), 'int... 3. Task(func_name=dense_pack.x86, args=(('TENSOR', (1, 128), 'int16'), ('TENSOR', (128, 128), 'int16... 4. Task(func_name=dense_nopack.x86, args=(('TENSOR', (1, 128), 'int16'), ('TENSOR', (8, 128), 'int16... 5. Task(func_name=dense_pack.x86, args=(('TENSOR', (1, 128), 'int16'), ('TENSOR', (8, 128), 'int16')... 6. Task(func_name=dense_nopack.x86, args=(('TENSOR', (1, 8), 'int16'), ('TENSOR', (128, 8), 'int16')... 7. Task(func_name=dense_pack.x86, args=(('TENSOR', (1, 8), 'int16'), ('TENSOR', (128, 8), 'int16'), ... 8. Task(func_name=dense_nopack.x86, args=(('TENSOR', (1, 128), 'int16'), ('TENSOR', (640, 128), 'int... 9. Task(func_name=dense_pack.x86, args=(('TENSOR', (1, 128), 'int16'), ('TENSOR', (640, 128), 'int16... ``` 2. 
Filter the list of tasks to be tuned: ``` # Only tune a single task (index 5) tvmc tune toycar.tflite -o out.txt --tasks 5 # Tunes tasks starting with index 6 tvmc tune toycar.tflite -o out.txt --tasks "6-" # Tune tasks 1,4,5,6,8,9 tvmc tune toycar.tflite -o out.txt --tasks "1,4-6,8-" ``` ## Tests I added a basic unit test for the `filter_tasks` utility in `tests/python/driver/tvmc/test_autotuner.py`. ## Open Questions - ~~While the (truncated) string representations of AutoTVM tasks are quite helpful to pick the correct tasks, using AutoScheduler the tasks can not really be distinguished from each other (only by index). Is there a way to get similar information from AutoScheduler tasks?~~
1 parent 8e2382e commit ffc1fc0

File tree

3 files changed

+227
-12
lines changed

3 files changed

+227
-12
lines changed

gallery/tutorial/tvmc_command_line_driver.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,11 @@
412412
# process, in terms of number of repetitions (``--repeat`` and ``--number``, for example), the tuning
413413
# algorithm to be used, and so on. Check ``tvmc tune --help`` for more information.
414414
#
415+
# In some situations it might be a good idea to only tune specific tasks (i.e. the most relevant ones)
416+
# to waste less time tuning simpler workloads. The flag `--tasks` offers versatile options to limit
417+
# the tasks used for tuning, e.g. `--tasks 20,22` or `--tasks 16-`. Available tasks can be printed
418+
# using `--tasks list`.
419+
#
415420

416421
################################################################################
417422
# Compiling an Optimized Model with Tuning Data

python/tvm/driver/tvmc/autotuner.py

Lines changed: 121 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,11 @@ def add_tune_parser(subparsers, _, json_params):
135135
help="enable tuning the graph through the AutoScheduler tuner",
136136
action="store_true",
137137
)
138+
parser.add_argument(
139+
"--tasks",
140+
default="all",
141+
help="which tasks should be tuned, i.e. 0 0,2 3-5 all list",
142+
)
138143

139144
auto_scheduler_group = parser.add_argument_group(
140145
"AutoScheduler options",
@@ -290,10 +295,100 @@ def drive_tune(args):
290295
include_simple_tasks=args.include_simple_tasks,
291296
log_estimated_latency=args.log_estimated_latency,
292297
additional_target_options=reconstruct_target_args(args),
298+
tasks_filter=args.tasks,
293299
**transform_args,
294300
)
295301

296302

303+
def _filter_expr_check(condition, message):
    """Raise ``AssertionError(message)`` unless ``condition`` holds.

    Used instead of a bare ``assert`` so that validation of the user-supplied
    filter expression is not stripped when Python runs with ``-O``.
    """
    if not condition:
        raise AssertionError(message)


def _parse_task_index(text):
    """Convert one numeric token of a filter expression to ``int``."""
    try:
        return int(text)
    except ValueError as err:
        raise ValueError("Invalid task index '{}' in filter expression".format(text)) from err


def filter_tasks(
    tasks: Union[List[auto_scheduler.SearchTask], List[autotvm.task.Task]],
    expr: str,
):
    """Utility to filter a list of tasks (AutoTVM or AutoScheduler) based on
    a user-supplied string expression.

    The expression is a comma-separated list of items. Each item is one of:

    - ``list`` or ``help``: request that the task list is printed.
    - ``all``: select every task.
    - ``N``: select the task with index ``N``.
    - ``A-B``, ``A-`` or ``-B``: select the inclusive index range; a missing
      bound defaults to the first or last task respectively.

    Whitespace around items is ignored.

    Parameters
    ----------
    tasks: list
        A list of extracted AutoTVM or AutoScheduler tasks.
    expr: str
        User-supplied expression to be used for filtering.

    Returns
    -------
    tasks: list
        The selected tasks in their original order. When the expression only
        contains ``list``/``help``/``all`` items the input list is returned as is.
    do_list: bool
        Whether the expression asked for the task list to be printed.

    Raises
    ------
    AssertionError
        If the expression is not a non-empty string, a range is malformed or
        inverted, or an index is out of range.
    ValueError
        If an index is not an integer.
    """
    _filter_expr_check(isinstance(expr, str), "Expected filter expression of string type")
    _filter_expr_check(len(expr) > 0, "Got empty filter expression")

    num_tasks = len(tasks)
    do_list = False
    do_filter = False
    # A set removes duplicates and gives O(1) membership tests below.
    selected = set()

    # groups of keywords are comma-separated
    for raw_item in expr.split(","):
        item = raw_item.strip()
        if item in ("list", "help"):
            do_list = True
        elif item == "all":
            selected.update(range(num_tasks))
        elif "-" in item:
            do_filter = True
            _filter_expr_check(item.count("-") == 1, "Malformed range expression")
            _filter_expr_check(len(item) > 1, "Missing lhs or rhs for range expression")
            lhs, rhs = item.split("-")
            # An omitted bound means "from the first" / "up to the last" task.
            lhs = _parse_task_index(lhs) if lhs.strip() else 0
            rhs = _parse_task_index(rhs) if rhs.strip() else num_tasks - 1
            _filter_expr_check(0 <= lhs < num_tasks, "Left-hand side expression out of range")
            _filter_expr_check(0 <= rhs < num_tasks, "Right-hand side expression out of range")
            # Without this check an inverted range would silently select nothing.
            _filter_expr_check(
                lhs <= rhs, "Left-hand side of range expression exceeds right-hand side"
            )
            selected.update(range(lhs, rhs + 1))
        else:
            do_filter = True
            idx = _parse_task_index(item)
            _filter_expr_check(0 <= idx < num_tasks, "Task index out of range")
            selected.add(idx)

    if do_filter:
        tasks = [task for i, task in enumerate(tasks) if i in selected]

    return tasks, do_list
354+
355+
356+
def gen_task_list(
    tasks: Union[List[auto_scheduler.SearchTask], List[autotvm.task.Task]],
    enable_autoscheduler: bool,
):
    """Utility for printing a list of tasks (AutoTVM or AutoScheduler)
    to the terminal.

    Parameters
    ----------
    tasks: list
        A list of extracted AutoTVM or AutoScheduler tasks.
    enable_autoscheduler: bool
        Whether the tasks are extracted with AutoScheduler or AutoTVM.

    Returns
    -------
    ret: str
        A header line followed by one line per task. AutoScheduler tasks are
        shown by their ``desc`` ("Unnamed" if missing); AutoTVM tasks by their
        string representation plus the size of their config space ("?" if the
        config space is not available).
    """
    max_len = 100

    def _trunc_helper(text):
        # Text of exactly ``max_len`` characters is shortened as well; this is kept
        # on purpose so the generated output stays unchanged.
        return text if len(text) < max_len else text[: max_len - 3] + "..."

    # NOTE(review): the entry prefixes below are reproduced verbatim from the reviewed
    # text; confirm the amount of leading whitespace against the expected test output.
    if enable_autoscheduler:
        # ``task.desc or ...`` also covers a ``None`` description, not only "".
        entries = [
            " {}. {}".format(i, _trunc_helper(task.desc or "Unnamed"))
            for i, task in enumerate(tasks)
        ]
    else:
        entries = [
            " {}. {} (len={})".format(
                i,
                _trunc_helper(str(task)),
                "?" if task.config_space is None else len(task.config_space),
            )
            for i, task in enumerate(tasks)
        ]

    return "Available Tasks for tuning:\n" + "\n".join(entries)
390+
391+
297392
def tune_model(
298393
tvmc_model: TVMCModel,
299394
target: str,
@@ -316,6 +411,7 @@ def tune_model(
316411
include_simple_tasks: bool = False,
317412
log_estimated_latency: bool = False,
318413
additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None,
414+
tasks_filter: str = "all",
319415
desired_layout: Optional[str] = None,
320416
desired_layout_ops: Optional[List[str]] = None,
321417
mixed_precision: bool = False,
@@ -376,6 +472,9 @@ def tune_model(
376472
If using the autoscheduler, write the estimated latency at each step of tuning to file.
377473
additional_target_options: Optional[Dict[str, Dict[str, Any]]]
378474
Additional target options in a dictionary to combine with initial Target arguments
475+
tasks_filter : str, optional
476+
Filter which tasks should be tuned or output a list of the extracted tasks.
477+
Examples: 0 0,2 3-5 all list
379478
desired_layout: str, optional
380479
Can be one of "NCHW" or "NHWC". When specified, compatible operations in the graph
381480
will have their layout set to this format. Tasks will then be tuned using this
@@ -391,7 +490,6 @@ def tune_model(
391490
mixed_precision_acc_type: str
392491
The accumulation data type to be used while mixed precision.
393492
394-
395493
Returns
396494
-------
397495
tuning_records : str
@@ -464,7 +562,6 @@ def tune_model(
464562
runner = local_server
465563

466564
if enable_autoscheduler:
467-
468565
tasks, weights = autoscheduler_get_tuning_tasks(
469566
mod=mod,
470567
params=params,
@@ -473,7 +570,27 @@ def tune_model(
473570
hardware_params=hardware_params,
474571
include_simple_tasks=include_simple_tasks,
475572
)
573+
else:
574+
tasks = autotvm_get_tuning_tasks(
575+
mod=mod,
576+
params=params,
577+
target=target,
578+
transform_args=transform_args,
579+
)
580+
581+
# Filter extracted tasks by provided user expression
582+
if tasks_filter:
583+
tasks, do_list = filter_tasks(tasks, tasks_filter)
584+
if do_list:
585+
print(gen_task_list(tasks, enable_autoscheduler))
586+
return None
587+
if len(tasks) == 0:
588+
logger.info("No tasks have been selected for tuning.")
589+
return None
590+
else:
591+
logger.info("Selected %s tasks for tuning.", len(tasks))
476592

593+
if enable_autoscheduler:
477594
# Create the autoscheduler tuning options
478595
tuning_options = auto_scheduler.TuningOptions(
479596
num_measure_trials=trials,
@@ -487,16 +604,9 @@ def tune_model(
487604
# Schedule the tasks (i.e., produce a schedule for each task)
488605
schedule_tasks(tasks, weights, tuning_options, prior_records, log_estimated_latency)
489606
else:
490-
tasks = autotvm_get_tuning_tasks(
491-
mod=mod,
492-
params=params,
493-
target=target,
494-
transform_args=transform_args,
495-
)
496-
497607
# In autotvm, trials is specified per task. We can convert the per-model input
498608
# provided to per-task trials by dividing by the number of tasks.
499-
trials = int(trials / max(len(tasks), 1))
609+
trials = int(max(1, trials / max(len(tasks), 1)))
500610
logger.info("Autotuning with %d trials per task.", trials)
501611

502612
tuning_options = {
@@ -710,7 +820,7 @@ def tune_tasks(
710820
early_stopping=early_stopping,
711821
measure_option=measure_option,
712822
callbacks=[
713-
autotvm.callback.progress_bar(trials, prefix=prefix),
823+
autotvm.callback.progress_bar(min(trials, len(tsk.config_space)), prefix=prefix),
714824
autotvm.callback.log_to_file(log_file),
715825
],
716826
)

tests/python/driver/tvmc/test_autotuner.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
from pathlib import Path
2525

2626
import tvm
27-
from tvm import autotvm
27+
import tvm.testing
28+
from tvm import autotvm, auto_scheduler
2829
from tvm.driver import tvmc
30+
from tvm.driver.tvmc.autotuner import filter_tasks, gen_task_list
2931

3032

3133
def _get_tasks(model):
@@ -207,3 +209,101 @@ def test_autotune_pass_context(mock_pc, onnx_mnist, tmpdir_factory):
207209
# AutoTVM overrides the pass context later in the pipeline to disable AlterOpLayout
208210
assert mock_pc.call_count == 2
209211
assert mock_pc.call_args_list[0][1]["opt_level"] == 3
212+
213+
214+
def test_filter_tasks_valid():
    """Check ``filter_tasks`` results for well-formed filter expressions.

    ``list``/``help``/``all`` on their own do not filter anything, so the full
    task list is returned; only the ``do_list`` flag differs.
    """
    tasks = list(range(10))
    # Each comparison must be asserted, otherwise the test can never fail.
    assert filter_tasks(tasks, "list") == (tasks, True)
    assert filter_tasks(tasks, "help") == (tasks, True)
    assert filter_tasks(tasks, "all") == (tasks, False)
    assert filter_tasks(tasks, "5") == ([5], False)
    assert filter_tasks(tasks, "1-5") == ([1, 2, 3, 4, 5], False)
    assert filter_tasks(tasks, "-5") == ([0, 1, 2, 3, 4, 5], False)
    assert filter_tasks(tasks, "6-") == ([6, 7, 8, 9], False)
    assert filter_tasks(tasks, "0,1-3,all") == (tasks, False)
    assert filter_tasks(tasks, "0,4-5,9,list") == ([0, 4, 5, 9], True)
224+
225+
226+
@pytest.mark.parametrize(
    "value,err_msg",
    [
        ("10", "Task index out of range"),
        ("5,10", "Task index out of range"),
        ("1-10", "Right-hand side expression out of range"),
        ("-10", "Right-hand side expression out of range"),
        ("-", "Missing lhs or rhs for range expression"),
        ("-10-", "Malformed range expression"),
        ("--", "Malformed range expression"),
    ],
)
def test_filter_tasks_invalid(value, err_msg):
    """Every malformed or out-of-range expression must fail with the matching message."""
    ten_tasks = list(range(10))
    with pytest.raises(AssertionError, match=err_msg):
        filter_tasks(ten_tasks, value)
241+
242+
243+
@pytest.mark.parametrize(
244+
"enable_autoscheduler,expected",
245+
[
246+
(
247+
False,
248+
"""Available Tasks for tuning:
249+
0. Task(func_name=taskA, args=[], kwargs={}, workload=('taskA',)) (len=?)
250+
1. Task(func_name=taskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBta... (len=?)
251+
2. Task(func_name=taskC, args=[], kwargs={}, workload=('taskC',)) (len=?)""",
252+
),
253+
(
254+
True,
255+
"""Available Tasks for tuning:
256+
0. taskA
257+
1. taskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBtaskBta...
258+
2. Unnamed""",
259+
),
260+
],
261+
)
262+
def test_print_task_list(enable_autoscheduler, expected):
    """Build three tasks (short name, very long name, third task) for the selected
    tuner and compare the rendered task list against the expected text."""
    long_name = "taskB" * 20  # very long name

    if enable_autoscheduler:
        auto_scheduler.search_task.TASK_INPUT_BUFFER_TABLE.clear()
        size = 64
        # Keyword arguments shared by all three search tasks.
        shared_kwargs = dict(
            func="matmul_auto_scheduler_test",
            args=(size, size, size),
            target="llvm",
            task_inputs={
                "test_input_0": tvm.runtime.ndarray.empty((64, 64)),
                "test_input_1": tvm.runtime.ndarray.empty((10, 20)),
                "test_input_2": tvm.runtime.ndarray.empty((30, 40, 50)),
            },
            task_inputs_overwrite=True,
        )
        tasks = [
            auto_scheduler.SearchTask(desc="taskA", **shared_kwargs),
            auto_scheduler.SearchTask(desc=long_name, **shared_kwargs),
            # missing description
            auto_scheduler.SearchTask(**shared_kwargs),
        ]
    else:
        tasks = [autotvm.task.Task(name, []) for name in ("taskA", long_name, "taskC")]

    assert gen_task_list(tasks, enable_autoscheduler) == expected
306+
307+
308+
if __name__ == "__main__":
309+
tvm.testing.main()

0 commit comments

Comments
 (0)