Skip to content

Commit 670c523

Browse files
committed
[TVMC][TRANSFORMS] ToMixedPrecision transform support with custom options enabled
Adds new command line options --mixed-precision --mixed-precision-ops --mixed-precision-input --mixed-precision-output and --desired-layout-ops This PR also enhances the python interface by replacing alter_layout to transform_args. transform_args is a dict with all tranform related options including existing desired_layout or alter_layout option.
1 parent e7ad4bc commit 670c523

File tree

5 files changed

+276
-58
lines changed

5 files changed

+276
-58
lines changed

python/tvm/driver/tvmc/autotuner.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from .model import TVMCModel
4040
from .target import target_from_cli, generate_target_args, reconstruct_target_args
4141
from .shape_parser import parse_shape_string
42-
from .transform import convert_graph_layout
42+
from .transform import generate_transform_args, parse_graph_transform_args, apply_graph_transforms
4343

4444

4545
# pylint: disable=invalid-name
@@ -127,12 +127,7 @@ def add_tune_parser(subparsers, _, json_params):
127127
metavar="PATH",
128128
help="path to an auto-tuning log file by AutoTVM.",
129129
)
130-
parser.add_argument(
131-
"--desired-layout",
132-
choices=["NCHW", "NHWC"],
133-
default=None,
134-
help="change the data layout of the whole graph",
135-
)
130+
generate_transform_args(parser)
136131
parser.add_argument(
137132
"--enable-autoscheduler",
138133
help="enable tuning the graph through the AutoScheduler tuner",
@@ -269,6 +264,8 @@ def drive_tune(args):
269264
rpc_hostname = None
270265
rpc_port = None
271266

267+
transform_args = parse_graph_transform_args(args)
268+
272269
tune_model(
273270
tvmc_model,
274271
args.target,
@@ -283,7 +280,7 @@ def drive_tune(args):
283280
tuner=args.tuner,
284281
min_repeat_ms=args.min_repeat_ms,
285282
early_stopping=args.early_stopping,
286-
desired_layout=args.desired_layout,
283+
transform_args=transform_args,
287284
timeout=args.timeout,
288285
repeat=args.repeat,
289286
number=args.number,
@@ -309,7 +306,7 @@ def tune_model(
309306
tuner: str = "xgb",
310307
min_repeat_ms: Optional[int] = None,
311308
early_stopping: Optional[int] = None,
312-
desired_layout: Optional[str] = None,
309+
transform_args: Optional[Dict[str, Any]] = None,
313310
timeout: int = 10,
314311
repeat: int = 1,
315312
number: int = 10,
@@ -354,10 +351,8 @@ def tune_model(
354351
Minimum time to run each trial. Defaults to 0 on x86 and 1000 on other targets.
355352
early_stopping : int, optional
356353
When specified, stop tuning after this number of trials if results aren't improving.
357-
desired_layout : str, optional
358-
Can be one of "NCHW" or "NHWC". When specified, compatible operations in the graph
359-
will have their layout set to this format. Tasks will then be tuned using this
360-
specified layout.
354+
transform_args: dict, optional
355+
Graph transformation arguments that are applied to the relay module.
361356
timeout : int, optional,
362357
If a kernel trial lasts longer than this duration in seconds, it will be
363358
considered a failure.
@@ -453,7 +448,7 @@ def tune_model(
453448
mod=mod,
454449
params=params,
455450
target=target,
456-
alter_layout=desired_layout,
451+
transform_args=transform_args,
457452
hardware_params=hardware_params,
458453
include_simple_tasks=include_simple_tasks,
459454
)
@@ -475,7 +470,7 @@ def tune_model(
475470
mod=mod,
476471
params=params,
477472
target=target,
478-
alter_layout=desired_layout,
473+
transform_args=transform_args,
479474
)
480475

481476
# In autotvm, trials is specified per task. We can convert the per-model input
@@ -504,7 +499,7 @@ def autotvm_get_tuning_tasks(
504499
params: Dict[str, tvm.nd.NDArray],
505500
target: str,
506501
target_host: Optional[str] = None,
507-
alter_layout: Optional[str] = None,
502+
transform_args: Optional[Dict[str, Any]] = None,
508503
):
509504
"""Get the autotvm tuning tasks for a given relay module.
510505
@@ -518,10 +513,8 @@ def autotvm_get_tuning_tasks(
518513
The compilation target.
519514
target_host : str, optional
520515
The compilation target for the host.
521-
alter_layout : str, optional
522-
The layout to convert the graph to. Note, the convert layout
523-
pass doesn't currently guarantee the whole of the graph will
524-
be converted to the chosen layout.
516+
transform_args: dict, optional
517+
Graph transformation arguments that are applied to the relay module.
525518
526519
Returns
527520
-------
@@ -530,8 +523,7 @@ def autotvm_get_tuning_tasks(
530523
"""
531524
target, target_host = Target.canon_target_and_host(target, target_host)
532525

533-
if alter_layout:
534-
mod = convert_graph_layout(mod, alter_layout)
526+
mod = apply_graph_transforms(mod, transform_args)
535527

536528
tasks = autotvm.task.extract_from_program(
537529
mod["main"],
@@ -547,7 +539,7 @@ def autoscheduler_get_tuning_tasks(
547539
params: Dict[str, tvm.nd.NDArray],
548540
target: str,
549541
target_host: Optional[str] = None,
550-
alter_layout: Optional[str] = None,
542+
transform_args: Optional[Dict[str, Any]] = None,
551543
hardware_params: Optional[HardwareParams] = None,
552544
include_simple_tasks: bool = False,
553545
):
@@ -563,10 +555,8 @@ def autoscheduler_get_tuning_tasks(
563555
The compilation target.
564556
target_host : str, optional
565557
The compilation target for the host.
566-
alter_layout : str, optional
567-
The layout to convert the graph to. Note, the convert layout
568-
pass doesn't currently guarantee the whole of the graph will
569-
be converted to the chosen layout.
558+
transform_args: dict, optional
559+
Graph transformation arguments that are applied to the relay module.
570560
hardware_params : Optional[HardwareParams]
571561
Hardware parameters used for the search tasks
572562
@@ -579,8 +569,7 @@ def autoscheduler_get_tuning_tasks(
579569
"""
580570
target, target_host = Target.canon_target_and_host(target, target_host)
581571

582-
if alter_layout:
583-
mod = convert_graph_layout(mod, alter_layout)
572+
mod = apply_graph_transforms(mod, transform_args)
584573

585574
# Extract the tasks
586575
tasks, task_weights = auto_scheduler.extract_tasks(

python/tvm/driver/tvmc/compiler.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from .target import target_from_cli, generate_target_args, reconstruct_target_args
3838
from .pass_config import parse_configs
3939
from .pass_list import parse_pass_list_str
40-
from .transform import convert_graph_layout
40+
from .transform import generate_transform_args, parse_graph_transform_args, apply_graph_transforms
4141
from .shape_parser import parse_shape_string
4242
from .workspace_pools import generate_workspace_pools_args, workspace_pools_recombobulate
4343

@@ -61,12 +61,7 @@ def add_compile_parser(subparsers, _, json_params):
6161
default="",
6262
help="the cross compiler options to generate target libraries, e.g. '-mfpu=neon-vfpv4'.",
6363
)
64-
parser.add_argument(
65-
"--desired-layout",
66-
choices=["NCHW", "NHWC"],
67-
default=None,
68-
help="change the data layout of the whole graph.",
69-
)
64+
generate_transform_args(parser)
7065
parser.add_argument(
7166
"--dump-code",
7267
metavar="FORMAT",
@@ -177,6 +172,7 @@ def drive_compile(args):
177172

178173
additional_targets = reconstruct_target_args(args)
179174
workspace_pools_target, extra_targets = target_from_cli(args.target, additional_targets)
175+
transform_args = parse_graph_transform_args(args)
180176

181177
compile_model(
182178
tvmc_model,
@@ -191,7 +187,7 @@ def drive_compile(args):
191187
output_format=args.output_format,
192188
dump_code=dump_code,
193189
target_host=None,
194-
desired_layout=args.desired_layout,
190+
transform_args=transform_args,
195191
disabled_pass=args.disabled_pass,
196192
pass_context_configs=args.pass_config,
197193
mod_name=args.module_name,
@@ -217,7 +213,7 @@ def compile_model(
217213
output_format: str = "so",
218214
dump_code: Optional[List[str]] = None,
219215
target_host: Optional[str] = None,
220-
desired_layout: Optional[str] = None,
216+
transform_args: Optional[Dict[str, Any]] = None,
221217
disabled_pass: Optional[str] = None,
222218
pass_context_configs: Optional[List[str]] = None,
223219
additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None,
@@ -260,10 +256,8 @@ def compile_model(
260256
target_host : str, optional
261257
The target of the host machine if host-side code
262258
needs to be generated.
263-
desired_layout: str, optional
264-
The layout to convert the graph to. Note, the convert layout
265-
pass doesn't currently guarantee the whole of the graph will
266-
be converted to the chosen layout.
259+
transform_args: dict, optional
260+
Graph transformation arguments that are applied to the relay module.
267261
disabled_pass: str, optional
268262
Comma-separated list of passes which needs to be disabled
269263
during compilation
@@ -310,8 +304,7 @@ def compile_model(
310304
disabled_pass=disabled_pass,
311305
instruments=instruments,
312306
):
313-
if desired_layout:
314-
mod = convert_graph_layout(mod, desired_layout)
307+
mod = apply_graph_transforms(mod, transform_args)
315308

316309
for partition_function, opts in zip(partition_functions, partition_opts):
317310
mod = partition_function(mod, params, mod_name=mod_name, **opts)

0 commit comments

Comments
 (0)