Skip to content

Commit a9ac14b

Browse files
committed
[TVMC][TRANSFORMS] ToMixedPrecision transform support with custom options enabled
Adds new command line options: --mixed-precision, --mixed-precision-ops, --mixed-precision-input and --mixed-precision-output. This PR also enhances the python interface by replacing alter_layout with transform_args. transform_args is a dict with all transform related options, including the existing alter_layout option.
1 parent e7ad4bc commit a9ac14b

File tree

4 files changed

+238
-47
lines changed

4 files changed

+238
-47
lines changed

python/tvm/driver/tvmc/autotuner.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from .model import TVMCModel
4040
from .target import target_from_cli, generate_target_args, reconstruct_target_args
4141
from .shape_parser import parse_shape_string
42-
from .transform import convert_graph_layout
42+
from .transform import generate_transform_args, parse_graph_transform_args, apply_graph_transforms
4343

4444

4545
# pylint: disable=invalid-name
@@ -127,12 +127,7 @@ def add_tune_parser(subparsers, _, json_params):
127127
metavar="PATH",
128128
help="path to an auto-tuning log file by AutoTVM.",
129129
)
130-
parser.add_argument(
131-
"--desired-layout",
132-
choices=["NCHW", "NHWC"],
133-
default=None,
134-
help="change the data layout of the whole graph",
135-
)
130+
generate_transform_args(parser)
136131
parser.add_argument(
137132
"--enable-autoscheduler",
138133
help="enable tuning the graph through the AutoScheduler tuner",
@@ -269,6 +264,8 @@ def drive_tune(args):
269264
rpc_hostname = None
270265
rpc_port = None
271266

267+
transform_args = parse_graph_transform_args(args)
268+
272269
tune_model(
273270
tvmc_model,
274271
args.target,
@@ -283,7 +280,7 @@ def drive_tune(args):
283280
tuner=args.tuner,
284281
min_repeat_ms=args.min_repeat_ms,
285282
early_stopping=args.early_stopping,
286-
desired_layout=args.desired_layout,
283+
transform_args=transform_args,
287284
timeout=args.timeout,
288285
repeat=args.repeat,
289286
number=args.number,
@@ -309,7 +306,7 @@ def tune_model(
309306
tuner: str = "xgb",
310307
min_repeat_ms: Optional[int] = None,
311308
early_stopping: Optional[int] = None,
312-
desired_layout: Optional[str] = None,
309+
transform_args: Optional[Dict[str, Any]] = None,
313310
timeout: int = 10,
314311
repeat: int = 1,
315312
number: int = 10,
@@ -354,10 +351,8 @@ def tune_model(
354351
Minimum time to run each trial. Defaults to 0 on x86 and 1000 on other targets.
355352
early_stopping : int, optional
356353
When specified, stop tuning after this number of trials if results aren't improving.
357-
desired_layout : str, optional
358-
Can be one of "NCHW" or "NHWC". When specified, compatible operations in the graph
359-
will have their layout set to this format. Tasks will then be tuned using this
360-
specified layout.
354+
transform_args: dict, optional
355+
Graph transformation arguments that are applied to the relay module.
361356
timeout : int, optional,
362357
If a kernel trial lasts longer than this duration in seconds, it will be
363358
considered a failure.
@@ -453,7 +448,7 @@ def tune_model(
453448
mod=mod,
454449
params=params,
455450
target=target,
456-
alter_layout=desired_layout,
451+
transform_args=transform_args,
457452
hardware_params=hardware_params,
458453
include_simple_tasks=include_simple_tasks,
459454
)
@@ -475,7 +470,7 @@ def tune_model(
475470
mod=mod,
476471
params=params,
477472
target=target,
478-
alter_layout=desired_layout,
473+
transform_args=transform_args,
479474
)
480475

481476
# In autotvm, trials is specified per task. We can convert the per-model input
@@ -504,7 +499,7 @@ def autotvm_get_tuning_tasks(
504499
params: Dict[str, tvm.nd.NDArray],
505500
target: str,
506501
target_host: Optional[str] = None,
507-
alter_layout: Optional[str] = None,
502+
transform_args: Optional[Dict[str, Any]] = None,
508503
):
509504
"""Get the autotvm tuning tasks for a given relay module.
510505
@@ -518,10 +513,8 @@ def autotvm_get_tuning_tasks(
518513
The compilation target.
519514
target_host : str, optional
520515
The compilation target for the host.
521-
alter_layout : str, optional
522-
The layout to convert the graph to. Note, the convert layout
523-
pass doesn't currently guarantee the whole of the graph will
524-
be converted to the chosen layout.
516+
transform_args: dict, optional
517+
Graph transformation arguments that are applied to the relay module.
525518
526519
Returns
527520
-------
@@ -530,8 +523,7 @@ def autotvm_get_tuning_tasks(
530523
"""
531524
target, target_host = Target.canon_target_and_host(target, target_host)
532525

533-
if alter_layout:
534-
mod = convert_graph_layout(mod, alter_layout)
526+
mod = apply_graph_transforms(mod, transform_args)
535527

536528
tasks = autotvm.task.extract_from_program(
537529
mod["main"],
@@ -547,7 +539,7 @@ def autoscheduler_get_tuning_tasks(
547539
params: Dict[str, tvm.nd.NDArray],
548540
target: str,
549541
target_host: Optional[str] = None,
550-
alter_layout: Optional[str] = None,
542+
transform_args: Optional[Dict[str, Any]] = None,
551543
hardware_params: Optional[HardwareParams] = None,
552544
include_simple_tasks: bool = False,
553545
):
@@ -563,10 +555,8 @@ def autoscheduler_get_tuning_tasks(
563555
The compilation target.
564556
target_host : str, optional
565557
The compilation target for the host.
566-
alter_layout : str, optional
567-
The layout to convert the graph to. Note, the convert layout
568-
pass doesn't currently guarantee the whole of the graph will
569-
be converted to the chosen layout.
558+
transform_args: dict, optional
559+
Graph transformation arguments that are applied to the relay module.
570560
hardware_params : Optional[HardwareParams]
571561
Hardware parameters used for the search tasks
572562
@@ -579,8 +569,7 @@ def autoscheduler_get_tuning_tasks(
579569
"""
580570
target, target_host = Target.canon_target_and_host(target, target_host)
581571

582-
if alter_layout:
583-
mod = convert_graph_layout(mod, alter_layout)
572+
mod = apply_graph_transforms(mod, transform_args)
584573

585574
# Extract the tasks
586575
tasks, task_weights = auto_scheduler.extract_tasks(

python/tvm/driver/tvmc/compiler.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from .target import target_from_cli, generate_target_args, reconstruct_target_args
3838
from .pass_config import parse_configs
3939
from .pass_list import parse_pass_list_str
40-
from .transform import convert_graph_layout
40+
from .transform import generate_transform_args, parse_graph_transform_args, apply_graph_transforms
4141
from .shape_parser import parse_shape_string
4242
from .workspace_pools import generate_workspace_pools_args, workspace_pools_recombobulate
4343

@@ -61,12 +61,7 @@ def add_compile_parser(subparsers, _, json_params):
6161
default="",
6262
help="the cross compiler options to generate target libraries, e.g. '-mfpu=neon-vfpv4'.",
6363
)
64-
parser.add_argument(
65-
"--desired-layout",
66-
choices=["NCHW", "NHWC"],
67-
default=None,
68-
help="change the data layout of the whole graph.",
69-
)
64+
generate_transform_args(parser)
7065
parser.add_argument(
7166
"--dump-code",
7267
metavar="FORMAT",
@@ -177,6 +172,7 @@ def drive_compile(args):
177172

178173
additional_targets = reconstruct_target_args(args)
179174
workspace_pools_target, extra_targets = target_from_cli(args.target, additional_targets)
175+
transform_args = parse_graph_transform_args(args)
180176

181177
compile_model(
182178
tvmc_model,
@@ -191,7 +187,7 @@ def drive_compile(args):
191187
output_format=args.output_format,
192188
dump_code=dump_code,
193189
target_host=None,
194-
desired_layout=args.desired_layout,
190+
transform_args=transform_args,
195191
disabled_pass=args.disabled_pass,
196192
pass_context_configs=args.pass_config,
197193
mod_name=args.module_name,
@@ -217,7 +213,7 @@ def compile_model(
217213
output_format: str = "so",
218214
dump_code: Optional[List[str]] = None,
219215
target_host: Optional[str] = None,
220-
desired_layout: Optional[str] = None,
216+
transform_args: Optional[Dict[str, Any]] = None,
221217
disabled_pass: Optional[str] = None,
222218
pass_context_configs: Optional[List[str]] = None,
223219
additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None,
@@ -260,10 +256,8 @@ def compile_model(
260256
target_host : str, optional
261257
The target of the host machine if host-side code
262258
needs to be generated.
263-
desired_layout: str, optional
264-
The layout to convert the graph to. Note, the convert layout
265-
pass doesn't currently guarantee the whole of the graph will
266-
be converted to the chosen layout.
259+
transform_args: dict, optional
260+
Graph transformation arguments that are applied to the relay module.
267261
disabled_pass: str, optional
268262
Comma-separated list of passes which needs to be disabled
269263
during compilation
@@ -310,8 +304,7 @@ def compile_model(
310304
disabled_pass=disabled_pass,
311305
instruments=instruments,
312306
):
313-
if desired_layout:
314-
mod = convert_graph_layout(mod, desired_layout)
307+
mod = apply_graph_transforms(mod, transform_args)
315308

316309
for partition_function, opts in zip(partition_functions, partition_opts):
317310
mod = partition_function(mod, params, mod_name=mod_name, **opts)

python/tvm/driver/tvmc/transform.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,76 @@
1313
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
1414
# KIND, either express or implied. See the License for the
1515
# specific language
16+
# pylint: disable=unused-argument
1617
"""
1718
TVMC Graph Transforms
1819
"""
1920

2021
from tvm import relay, transform
2122
from tvm.driver.tvmc import TVMCException
2223

24+
# ToMixedPrecision
25+
ACC_DTYPE = "float32"
26+
27+
28+
def mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
    """Custom conversion rule installed on selected ops for ToMixedPrecision.

    Parameters
    ----------
    call_node : relay.Call
        The call node being inspected (unused; the parameter is required by
        the FTVMMixedPrecisionConversionType attribute signature).
    mixed_precision_type : str
        The low-precision dtype the pass is converting to (e.g. "float16").

    Returns
    -------
    list
        [conversion category, accumulation dtype, output dtype] as expected
        by the ToMixedPrecision pass.
    """
    # ACC_DTYPE is only read here, so no ``global`` declaration is needed.
    # It is published by convert_to_mixed_precision before the pass runs.
    return [
        relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
        ACC_DTYPE,
        mixed_precision_type,
    ]
35+
36+
37+
class MixedPrecision(object):
    """Context manager that temporarily swaps in the custom mixed precision
    conversion rule on a set of operators, restoring the previous attribute
    on exit."""

    def __init__(self, ops):
        """Record the operators whose conversion attribute will be patched.

        Parameters
        ----------
        ops : list
            Names of the operators to patch.
        """
        self.older_attr = {}
        self.ops = ops
        self.attr_key = "FTVMMixedPrecisionConversionType"

    def __enter__(self):
        # Save each operator's current rule, then install the custom one.
        for name in self.ops:
            operator = relay.op.get(name)
            self.older_attr[name] = operator.get_attr(self.attr_key)
            operator.reset_attr(self.attr_key)
            operator.set_attr(self.attr_key, mixed_precision_rule)
        return self

    def __exit__(self, ptype, value, trace):
        # Remove the custom rule and put back whatever was there before.
        for name in self.ops:
            operator = relay.op.get(name)
            operator.reset_attr(self.attr_key)
            previous = self.older_attr[name]
            if previous:
                operator.set_attr(self.attr_key, previous)
66+
67+
68+
def convert_to_mixed_precision(mod, ops, input_type, out_type):
    """Convert the given operators of the module to mixed precision.

    Parameters
    ----------
    mod : tvm.IRModule
        The relay module to convert.
    ops : str
        Comma separated operator names to convert (e.g. "nn.conv2d,nn.dense").
    input_type : str
        Input precision type ("float16" or "float32").
    out_type : str
        Output / accumulator precision type ("float16" or "float32").

    Returns
    -------
    mod : tvm.IRModule
        The converted module.

    Raises
    ------
    TVMCException
        If the ToMixedPrecision pass fails on the module.
    """

    # mixed_precision_rule reads ACC_DTYPE while the pass runs, so publish
    # the requested accumulator/output dtype before executing the pass.
    global ACC_DTYPE
    ACC_DTYPE = out_type

    with MixedPrecision(ops.split(",")):
        # Pass input_type through to the pass; previously it was silently
        # ignored, so --mixed-precision-input had no effect.
        seq = transform.Sequential(
            [
                relay.transform.InferType(),
                relay.transform.ToMixedPrecision(mixed_precision_type=input_type),
            ]
        )
        with transform.PassContext(
            config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}, opt_level=3
        ):
            try:
                return seq(mod)
            except Exception as err:
                # Chain the original error so the root cause is preserved.
                raise TVMCException(
                    "Error converting mixed precision : {0}".format(str(err))
                ) from err
85+
2386

2487
def convert_graph_layout(mod, desired_layout):
2588
"""Alter the layout of the input graph.
@@ -58,3 +121,99 @@ def convert_graph_layout(mod, desired_layout):
58121
return seq(mod)
59122
except Exception as err:
60123
raise TVMCException("Error converting layout to {0}: {1}".format(desired_layout, str(err)))
124+
125+
126+
def apply_graph_transforms(mod, args):
    """Apply the requested graph transformations to the relay module.

    Currently supported transforms are layout conversion (ConvertLayout)
    and mixed precision conversion (ToMixedPrecision).

    Parameters
    ----------
    mod : tvm.IRModule
        The relay module to transform.
    args : dict
        The transform arguments (as produced by parse_graph_transform_args).
        May be None or empty, in which case the module is returned unchanged.

    Returns
    -------
    mod : tvm.IRModule
        The transformed module.
    """
    # Nothing requested: return the module untouched.
    if not args:
        return mod

    # AlterLayout
    if args.get("desired_layout"):
        mod = convert_graph_layout(mod, args["desired_layout"])

    # ToMixedPrecision
    if args.get("mixed_precision"):
        mod = convert_to_mixed_precision(
            mod,
            args.get("mixed_precision_ops", "nn.conv2d,nn.dense"),
            args.get("mixed_precision_input", "float16"),
            args.get("mixed_precision_output", "float16"),
        )
    return mod
157+
158+
159+
def parse_graph_transform_args(args):
    """Parse incoming options for graph transform arguments.

    Parameters
    ----------
    args : argparse.Namespace
        Arguments from command line parser.

    Returns
    -------
    transform_args : dict
        Graph transform arguments keyed by option name. Options that the
        parser did not register are mapped to None so downstream code can
        treat them as disabled.
    """

    args_dict = vars(args)

    keys = (
        "desired_layout",
        "mixed_precision",
        "mixed_precision_ops",
        "mixed_precision_input",
        "mixed_precision_output",
    )
    # Use .get() so a subcommand that did not register the transform options
    # (via generate_transform_args) does not raise KeyError here.
    return {key: args_dict.get(key) for key in keys}
184+
185+
186+
def generate_transform_args(parser):
    """Register all graph transform related command line options on *parser*."""

    # (flag, keyword arguments) for every transform option, in the order in
    # which they should appear in --help output.
    options = [
        # AlterLayout
        (
            "--desired-layout",
            dict(
                choices=["NCHW", "NHWC"],
                default=None,
                help="Change the data layout of the whole graph.",
            ),
        ),
        # ToMixedPrecision
        (
            "--mixed-precision",
            dict(
                help="Enable mixed precision conversion",
                action="store_true",
            ),
        ),
        (
            "--mixed-precision-ops",
            dict(
                default="nn.conv2d,nn.dense",
                help="List of operators to be converted to mixed precision",
            ),
        ),
        (
            "--mixed-precision-input",
            dict(
                choices=["float16", "float32"],
                default="float16",
                help="Input precision type",
            ),
        ),
        (
            "--mixed-precision-output",
            dict(
                choices=["float16", "float32"],
                default="float16",
                help="Output or accumulator precision type",
            ),
        ),
    ]
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)

0 commit comments

Comments
 (0)