Skip to content

Commit 1949d26

Browse files
committed
* review comments.
1 parent 670c523 commit 1949d26

File tree

5 files changed

+102
-75
lines changed

5 files changed

+102
-75
lines changed

python/tvm/driver/tvmc/autotuner.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
# pylint: disable=unused-argument
1718
"""
1819
Provides support to auto-tuning networks using AutoTVM.
1920
"""
@@ -280,7 +281,6 @@ def drive_tune(args):
280281
tuner=args.tuner,
281282
min_repeat_ms=args.min_repeat_ms,
282283
early_stopping=args.early_stopping,
283-
transform_args=transform_args,
284284
timeout=args.timeout,
285285
repeat=args.repeat,
286286
number=args.number,
@@ -289,6 +289,7 @@ def drive_tune(args):
289289
include_simple_tasks=args.include_simple_tasks,
290290
log_estimated_latency=args.log_estimated_latency,
291291
additional_target_options=reconstruct_target_args(args),
292+
**transform_args,
292293
)
293294

294295

@@ -306,7 +307,6 @@ def tune_model(
306307
tuner: str = "xgb",
307308
min_repeat_ms: Optional[int] = None,
308309
early_stopping: Optional[int] = None,
309-
transform_args: Optional[Dict[str, Any]] = None,
310310
timeout: int = 10,
311311
repeat: int = 1,
312312
number: int = 10,
@@ -315,6 +315,12 @@ def tune_model(
315315
include_simple_tasks: bool = False,
316316
log_estimated_latency: bool = False,
317317
additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None,
318+
desired_layout: Optional[str] = None,
319+
desired_layout_ops: Optional[List[str]] = None,
320+
mixed_precision: bool = False,
321+
mixed_precision_ops: Optional[List[str]] = None,
322+
mixed_precision_calculation_type: Optional[str] = None,
323+
mixed_precision_acc_type: Optional[str] = None,
318324
):
319325
"""Use tuning to automatically optimize the functions in a model.
320326
@@ -371,12 +377,28 @@ def tune_model(
371377
If using the autoscheduler, write the estimated latency at each step of tuning to file.
372378
additional_target_options: Optional[Dict[str, Dict[str, Any]]]
373379
Additional target options in a dictionary to combine with initial Target arguments
380+
desired_layout: str, optional
381+
Can be one of "NCHW" or "NHWC". When specified, compatible operations in the graph
382+
will have their layout set to this format. Tasks will then be tuned using this
383+
specified layout.
384+
desired_layout_ops: list[str], optional
385+
The list of operators to be transformed with desired layout.
386+
mixed_precision: bool
387+
Whether to enable the mixed precision transformation.
388+
mixed_precision_ops: list[str], optional
389+
The list of operators to be converted to mixed precision.
390+
mixed_precision_calculation_type: str
391+
The calculation dtype to be used during mixed precision conversion.
392+
mixed_precision_acc_type: str
393+
The accumulation data type to be used during mixed precision conversion.
394+
374395
375396
Returns
376397
-------
377398
tuning_records : str
378399
The path to the produced tuning log file.
379400
"""
401+
transform_args = parse_graph_transform_args(locals())
380402
target, extra_targets = target_from_cli(target, additional_target_options)
381403
target, target_host = Target.canon_target_and_host(target, target_host)
382404
# TODO(jwfromm) Remove this deepcopy once AlterOpLayout bug that mutates source

python/tvm/driver/tvmc/compiler.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
# pylint: disable=unused-argument
1718
"""
1819
Provides support to compile networks both AOT and JIT.
1920
"""
@@ -187,14 +188,14 @@ def drive_compile(args):
187188
output_format=args.output_format,
188189
dump_code=dump_code,
189190
target_host=None,
190-
transform_args=transform_args,
191191
disabled_pass=args.disabled_pass,
192192
pass_context_configs=args.pass_config,
193193
mod_name=args.module_name,
194194
additional_target_options=additional_targets,
195195
workspace_pools=(
196196
workspace_pools_recombobulate(args, [workspace_pools_target], extra_targets)
197197
),
198+
**transform_args,
198199
)
199200

200201
return 0
@@ -213,14 +214,19 @@ def compile_model(
213214
output_format: str = "so",
214215
dump_code: Optional[List[str]] = None,
215216
target_host: Optional[str] = None,
216-
transform_args: Optional[Dict[str, Any]] = None,
217217
disabled_pass: Optional[str] = None,
218218
pass_context_configs: Optional[List[str]] = None,
219219
additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None,
220220
use_vm: bool = False,
221221
mod_name: Optional[str] = "default",
222222
workspace_pools: Optional[WorkspaceMemoryPools] = None,
223223
instruments: Optional[Sequence[PassInstrument]] = None,
224+
desired_layout: Optional[str] = None,
225+
desired_layout_ops: Optional[List[str]] = None,
226+
mixed_precision: bool = False,
227+
mixed_precision_ops: Optional[List[str]] = None,
228+
mixed_precision_calculation_type: Optional[str] = None,
229+
mixed_precision_acc_type: Optional[str] = None,
224230
):
225231
"""Compile a model from a supported framework into a TVM module.
226232
@@ -256,8 +262,6 @@ def compile_model(
256262
target_host : str, optional
257263
The target of the host machine if host-side code
258264
needs to be generated.
259-
transform_args: dict, optional
260-
Graph transformation arguments that are applied to the relay module.
261265
disabled_pass: str, optional
262266
Comma-separated list of passes which needs to be disabled
263267
during compilation
@@ -275,6 +279,20 @@ def compile_model(
275279
compilation.
276280
instruments: Optional[Sequence[PassInstrument]]
277281
The list of pass instrument implementations.
282+
desired_layout: str, optional
283+
Can be one of "NCHW" or "NHWC". When specified, compatible operations in the graph
284+
will have their layout set to this format. The model will then be compiled using this
285+
specified layout.
286+
desired_layout_ops: list[str], optional
287+
The list of operators to be transformed with desired layout.
288+
mixed_precision: bool
289+
Whether to enable the mixed precision transformation.
290+
mixed_precision_ops: list[str], optional
291+
The list of operators to be converted to mixed precision.
292+
mixed_precision_calculation_type: str
293+
The calculation dtype to be used during mixed precision conversion.
294+
mixed_precision_acc_type: str
295+
The accumulation data type to be used during mixed precision conversion.
278296
279297
Returns
280298
-------
@@ -304,6 +322,7 @@ def compile_model(
304322
disabled_pass=disabled_pass,
305323
instruments=instruments,
306324
):
325+
transform_args = parse_graph_transform_args(locals())
307326
mod = apply_graph_transforms(mod, transform_args)
308327

309328
for partition_function, opts in zip(partition_functions, partition_opts):

python/tvm/driver/tvmc/transform.py

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,40 +21,42 @@
2121
from tvm import relay, transform
2222
from tvm.driver.tvmc import TVMCException
2323

24-
# ToMixedPrecision
25-
ACC_DTYPE = "float32"
2624

25+
def generate_mixed_precision_rule(acc_dtype):
26+
def _mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
27+
return [
28+
relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
29+
acc_dtype,
30+
mixed_precision_type,
31+
]
2732

28-
def mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
29-
global ACC_DTYPE
30-
return [
31-
relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
32-
ACC_DTYPE,
33-
mixed_precision_type,
34-
]
33+
return _mixed_precision_rule
3534

3635

3736
class MixedPrecision(object):
3837
"""Temporarily changes attr of ops to enable required precision."""
3938

40-
def __init__(self, ops):
39+
def __init__(self, ops, acc_type):
4140
"""Saves the required info for RAII pattern usage.
4241
4342
Parameters
4443
----------
4544
ops : list
4645
list of operators
46+
acc_type: str
47+
Output or accumulation precision to be used.
4748
"""
4849
self.older_attr = {}
4950
self.ops = ops
51+
self.acc_type = acc_type
5052
self.attr_key = "FTVMMixedPrecisionConversionType"
5153

5254
def __enter__(self):
5355
for op_name in self.ops:
5456
op = relay.op.get(op_name)
5557
self.older_attr[op_name] = op.get_attr(self.attr_key)
5658
op.reset_attr(self.attr_key)
57-
op.set_attr(self.attr_key, mixed_precision_rule)
59+
op.set_attr(self.attr_key, generate_mixed_precision_rule(self.acc_type))
5860
return self
5961

6062
def __exit__(self, ptype, value, trace):
@@ -65,20 +67,18 @@ def __exit__(self, ptype, value, trace):
6567
op.set_attr(self.attr_key, self.older_attr[op_name])
6668

6769

68-
def convert_to_mixed_precision(
69-
mod, ops="nn.conv2d,nn.dense", input_type="float16", out_type="float16"
70-
):
70+
def convert_to_mixed_precision(mod, ops=None, calculation_type="float16", acc_type="float16"):
7171
"""Converts the operator datatypes
7272
7373
Parameters
7474
----------
7575
mod : tvm.IRModule
7676
The relay module to convert.
77-
ops : str
77+
ops : list
7878
List of operators to be precision converted.
79-
input_type: str
79+
calculation_type: str
8080
Calculation precision to be used.
81-
output_type: str
81+
acc_type: str
8282
Output or accumulation precision to be used.
8383
8484
Returns
@@ -87,10 +87,10 @@ def convert_to_mixed_precision(
8787
The converted module.
8888
"""
8989

90-
global ACC_DTYPE
91-
ACC_DTYPE = out_type
90+
if ops is None:
91+
ops = ["nn.conv2d", "nn.dense"]
9292

93-
with MixedPrecision(ops.split(",")):
93+
with MixedPrecision(ops, acc_type):
9494
seq = transform.Sequential(
9595
[relay.transform.InferType(), relay.transform.ToMixedPrecision()]
9696
)
@@ -103,7 +103,7 @@ def convert_to_mixed_precision(
103103
raise TVMCException("Error converting mixed precision : {0}".format(str(err)))
104104

105105

106-
def convert_graph_layout(mod, desired_layout, ops="nn.conv2d,nn.conv2d_transpose,qnn.conv2d"):
106+
def convert_graph_layout(mod, desired_layout, ops=None):
107107
"""Alter the layout of the input graph.
108108
109109
Parameters
@@ -112,16 +112,18 @@ def convert_graph_layout(mod, desired_layout, ops="nn.conv2d,nn.conv2d_transpose
112112
The relay module to convert.
113113
desired_layout : str
114114
The layout to convert to.
115-
ops : str
115+
ops : list
116116
List of operators to be layout converted.
117117
118118
Returns
119119
-------
120120
mod : tvm.IRModule
121121
The converted module.
122122
"""
123+
if ops is None:
124+
ops = ["nn.conv2d", "nn.conv2d_transpose", "qnn.conv2d"]
123125

124-
desired_layouts = {op: [desired_layout, "default"] for op in ops.split(",")}
126+
desired_layouts = {op: [desired_layout, "default"] for op in ops}
125127

126128
# Convert the layout of the graph where possible.
127129
seq = transform.Sequential(
@@ -164,9 +166,9 @@ def apply_graph_transforms(mod, args):
164166
if args.get("mixed_precision", False):
165167
mod = convert_to_mixed_precision(
166168
mod,
167-
args.get("mixed_precision_ops", "nn.conv2d,nn.dense"),
168-
args.get("mixed_precision_input", "float16"),
169-
args.get("mixed_precision_output", "float16"),
169+
args.get("mixed_precision_ops"),
170+
args.get("mixed_precision_calculation_type"),
171+
args.get("mixed_precision_acc_type"),
170172
)
171173
return mod
172174

@@ -176,26 +178,27 @@ def parse_graph_transform_args(args):
176178
177179
Parameters
178180
----------
179-
args: argparse.Namespace
180-
Arguments from command line parser.
181+
args: argparse.Namespace or dict
182+
Arguments.
181183
182184
Returns
183185
-------
184186
transform_args : dict
185187
Graph transform arguments
186188
"""
187189

188-
args_dict = vars(args)
190+
if not isinstance(args, dict):
191+
args = vars(args)
189192

190193
transform_args = [
191194
"desired_layout",
192195
"desired_layout_ops",
193196
"mixed_precision",
194197
"mixed_precision_ops",
195-
"mixed_precision_input",
196-
"mixed_precision_output",
198+
"mixed_precision_calculation_type",
199+
"mixed_precision_acc_type",
197200
]
198-
transform_args = {key: args_dict.get(key, None) for key in transform_args}
201+
transform_args = {key: args.get(key, None) for key in transform_args}
199202
return transform_args
200203

201204

@@ -211,7 +214,8 @@ def generate_transform_args(parser):
211214
)
212215
parser.add_argument(
213216
"--desired-layout-ops",
214-
default="nn.conv2d,nn.conv2d_transpose,qnn.conv2d",
217+
default=["nn.conv2d", "nn.conv2d_transpose", "qnn.conv2d"],
218+
nargs="+",
215219
help="List of operators to be layout converted.",
216220
)
217221

@@ -223,18 +227,19 @@ def generate_transform_args(parser):
223227
)
224228
parser.add_argument(
225229
"--mixed-precision-ops",
226-
default="nn.conv2d,nn.dense",
230+
default=["nn.conv2d", "nn.dense"],
231+
nargs="+",
227232
help="List of operators to be converted to mixed precision",
228233
)
229234
parser.add_argument(
230-
"--mixed-precision-input",
235+
"--mixed-precision-calculation-type",
231236
choices=["float16", "float32"],
232237
default="float16",
233-
help="Input precision type",
238+
help="Calculation precision type",
234239
)
235240
parser.add_argument(
236-
"--mixed-precision-output",
241+
"--mixed-precision-acc-type",
237242
choices=["float16", "float32"],
238243
default="float16",
239-
help="Output or accumulator precision type",
244+
help="Accumulator precision type",
240245
)

tests/python/driver/tvmc/test_compiler.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def verify_compile_tflite_module(model, shape_dict=None, use_vm=False):
7272
tvmc_model,
7373
target="llvm",
7474
dump_code="ll",
75-
transform_args={"desired_layout": "NCHW"},
75+
desired_layout="NCHW",
7676
use_vm=use_vm,
7777
)
7878
dumps_path = tvmc_package.package_path + ".ll"
@@ -290,9 +290,7 @@ def test_cross_compile_options_aarch64_onnx_module(onnx_resnet50):
290290
def verify_compile_paddle_module(model, shape_dict=None):
291291
pytest.importorskip("paddle")
292292
tvmc_model = tvmc.load(model, "paddle", shape_dict=shape_dict)
293-
tvmc_package = tvmc.compile(
294-
tvmc_model, target="llvm", dump_code="ll", transform_args={"desired_layout": "NCHW"}
295-
)
293+
tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW")
296294
dumps_path = tvmc_package.package_path + ".ll"
297295

298296
# check for output types
@@ -374,7 +372,7 @@ def test_compile_opencl(tflite_mobilenet_v1_0_25_128):
374372
tvmc_package = tvmc.compile(
375373
tvmc_model,
376374
target="opencl -host=llvm",
377-
transform_args={"desired_layout": "NCHW"},
375+
desired_layout="NCHW",
378376
dump_code="asm",
379377
)
380378
dumps_path = tvmc_package.package_path + ".asm"

0 commit comments

Comments
 (0)