3939from .model import TVMCModel
4040from .target import target_from_cli , generate_target_args , reconstruct_target_args
4141from .shape_parser import parse_shape_string
42- from .transform import convert_graph_layout
42+ from .transform import generate_transform_args , parse_graph_transform_args , apply_graph_transforms
4343
4444
4545# pylint: disable=invalid-name
@@ -127,12 +127,7 @@ def add_tune_parser(subparsers, _, json_params):
127127 metavar = "PATH" ,
128128 help = "path to an auto-tuning log file by AutoTVM." ,
129129 )
130- parser .add_argument (
131- "--desired-layout" ,
132- choices = ["NCHW" , "NHWC" ],
133- default = None ,
134- help = "change the data layout of the whole graph" ,
135- )
130+ generate_transform_args (parser )
136131 parser .add_argument (
137132 "--enable-autoscheduler" ,
138133 help = "enable tuning the graph through the AutoScheduler tuner" ,
@@ -269,6 +264,8 @@ def drive_tune(args):
269264 rpc_hostname = None
270265 rpc_port = None
271266
267+ transform_args = parse_graph_transform_args (args )
268+
272269 tune_model (
273270 tvmc_model ,
274271 args .target ,
@@ -283,7 +280,7 @@ def drive_tune(args):
283280 tuner = args .tuner ,
284281 min_repeat_ms = args .min_repeat_ms ,
285282 early_stopping = args .early_stopping ,
286- desired_layout = args . desired_layout ,
283+ transform_args = transform_args ,
287284 timeout = args .timeout ,
288285 repeat = args .repeat ,
289286 number = args .number ,
@@ -309,7 +306,7 @@ def tune_model(
309306 tuner : str = "xgb" ,
310307 min_repeat_ms : Optional [int ] = None ,
311308 early_stopping : Optional [int ] = None ,
312- desired_layout : Optional [str ] = None ,
309+ transform_args : Optional [Dict [ str , Any ] ] = None ,
313310 timeout : int = 10 ,
314311 repeat : int = 1 ,
315312 number : int = 10 ,
@@ -354,10 +351,8 @@ def tune_model(
354351 Minimum time to run each trial. Defaults to 0 on x86 and 1000 on other targets.
355352 early_stopping : int, optional
356353 When specified, stop tuning after this number of trials if results aren't improving.
357- desired_layout : str, optional
358- Can be one of "NCHW" or "NHWC". When specified, compatible operations in the graph
359- will have their layout set to this format. Tasks will then be tuned using this
360- specified layout.
354+ transform_args: dict, optional
355+ Graph transformation arguments that are applied to the relay module.
361356 timeout : int, optional,
362357 If a kernel trial lasts longer than this duration in seconds, it will be
363358 considered a failure.
@@ -453,7 +448,7 @@ def tune_model(
453448 mod = mod ,
454449 params = params ,
455450 target = target ,
456- alter_layout = desired_layout ,
451+ transform_args = transform_args ,
457452 hardware_params = hardware_params ,
458453 include_simple_tasks = include_simple_tasks ,
459454 )
@@ -475,7 +470,7 @@ def tune_model(
475470 mod = mod ,
476471 params = params ,
477472 target = target ,
478- alter_layout = desired_layout ,
473+ transform_args = transform_args ,
479474 )
480475
481476 # In autotvm, trials is specified per task. We can convert the per-model input
@@ -504,7 +499,7 @@ def autotvm_get_tuning_tasks(
504499 params : Dict [str , tvm .nd .NDArray ],
505500 target : str ,
506501 target_host : Optional [str ] = None ,
507- alter_layout : Optional [str ] = None ,
502+ transform_args : Optional [Dict [ str , Any ] ] = None ,
508503):
509504 """Get the autotvm tuning tasks for a given relay module.
510505
@@ -518,10 +513,8 @@ def autotvm_get_tuning_tasks(
518513 The compilation target.
519514 target_host : str, optional
520515 The compilation target for the host.
521- alter_layout : str, optional
522- The layout to convert the graph to. Note, the convert layout
523- pass doesn't currently guarantee the whole of the graph will
524- be converted to the chosen layout.
516+ transform_args: dict, optional
517+ Graph transformation arguments that are applied to the relay module.
525518
526519 Returns
527520 -------
@@ -530,8 +523,7 @@ def autotvm_get_tuning_tasks(
530523 """
531524 target , target_host = Target .canon_target_and_host (target , target_host )
532525
533- if alter_layout :
534- mod = convert_graph_layout (mod , alter_layout )
526+ mod = apply_graph_transforms (mod , transform_args )
535527
536528 tasks = autotvm .task .extract_from_program (
537529 mod ["main" ],
@@ -547,7 +539,7 @@ def autoscheduler_get_tuning_tasks(
547539 params : Dict [str , tvm .nd .NDArray ],
548540 target : str ,
549541 target_host : Optional [str ] = None ,
550- alter_layout : Optional [str ] = None ,
542+ transform_args : Optional [Dict [ str , Any ] ] = None ,
551543 hardware_params : Optional [HardwareParams ] = None ,
552544 include_simple_tasks : bool = False ,
553545):
@@ -563,10 +555,8 @@ def autoscheduler_get_tuning_tasks(
563555 The compilation target.
564556 target_host : str, optional
565557 The compilation target for the host.
566- alter_layout : str, optional
567- The layout to convert the graph to. Note, the convert layout
568- pass doesn't currently guarantee the whole of the graph will
569- be converted to the chosen layout.
558+ transform_args: dict, optional
559+ Graph transformation arguments that are applied to the relay module.
570560 hardware_params : Optional[HardwareParams]
571561 Hardware parameters used for the search tasks
572562
@@ -579,8 +569,7 @@ def autoscheduler_get_tuning_tasks(
579569 """
580570 target , target_host = Target .canon_target_and_host (target , target_host )
581571
582- if alter_layout :
583- mod = convert_graph_layout (mod , alter_layout )
572+ mod = apply_graph_transforms (mod , transform_args )
584573
585574 # Extract the tasks
586575 tasks , task_weights = auto_scheduler .extract_tasks (
0 commit comments