Skip to content

Commit 09a0ea0

Browse files
srkreddy1238tqchen
authored andcommitted
[FIX][TVMC] Fix the mixed precision conversion pipeline
Fixed the mixed precision conversion pipeline issue.
1 parent ed2c26a commit 09a0ea0

File tree

5 files changed

+10
-4
lines changed

5 files changed

+10
-4
lines changed

python/tvm/driver/tvmc/autotuner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ def autotvm_get_tuning_tasks(
672672
"""
673673
target, target_host = Target.canon_target_and_host(target, target_host)
674674

675-
mod = apply_graph_transforms(mod, transform_args)
675+
mod = apply_graph_transforms(mod, transform_args, params)
676676

677677
tasks = autotvm.task.extract_from_program(
678678
mod["main"],
@@ -718,7 +718,7 @@ def autoscheduler_get_tuning_tasks(
718718
"""
719719
target, target_host = Target.canon_target_and_host(target, target_host)
720720

721-
mod = apply_graph_transforms(mod, transform_args)
721+
mod = apply_graph_transforms(mod, transform_args, params)
722722

723723
# Extract the tasks
724724
tasks, task_weights = auto_scheduler.extract_tasks(

python/tvm/driver/tvmc/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ def compile_model(
401401
instruments=instruments,
402402
):
403403
transform_args = parse_graph_transform_args(locals())
404-
mod = apply_graph_transforms(mod, transform_args)
404+
mod = apply_graph_transforms(mod, transform_args, params)
405405

406406
for partition_function, opts in zip(partition_functions, partition_opts):
407407
mod = partition_function(mod, params, mod_name=mod_name, **opts)

python/tvm/driver/tvmc/transform.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def layout_helper(layout):
162162
raise TVMCException("Error converting layouts: {}".format(str(err)))
163163

164164

165-
def apply_graph_transforms(mod, args):
165+
def apply_graph_transforms(mod, args, params=None):
166166
"""Alter the layout of the input graph.
167167
168168
Parameters
@@ -171,6 +171,8 @@ def apply_graph_transforms(mod, args):
171171
The relay module to convert.
172172
args : dict
173173
The transform arguments.
174+
params: dict
175+
Module params
174176
175177
Returns
176178
-------
@@ -188,6 +190,7 @@ def apply_graph_transforms(mod, args):
188190

189191
# ToMixedPrecision
190192
if args.get("mixed_precision", False):
193+
mod = relay.quantize.prerequisite_optimize(mod, params)
191194
mod = convert_to_mixed_precision(
192195
mod,
193196
args.get("mixed_precision_ops"),

tests/python/driver/tvmc/test_transform.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def check(self, func):
226226
"mixed_precision_calculation_type": "float16",
227227
"mixed_precision_acc_type": "float16",
228228
},
229+
params,
229230
)
230231
ret = CheckOpMutator("float16", "float16", "nn.conv2d").check(mod["main"])
231232
assert ret
@@ -240,6 +241,7 @@ def check(self, func):
240241
"mixed_precision_calculation_type": "float16",
241242
"mixed_precision_acc_type": "float32",
242243
},
244+
params,
243245
)
244246
ret = CheckOpMutator("float16", "float32", "nn.conv2d").check(mod["main"])
245247
assert ret

tests/python/relay/opencl_texture/test_network.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def _test_mobilenet_v1(remote, target, calc_dtype, executor_type, acc_dtype):
4747
"mixed_precision_calculation_type": calc_dtype,
4848
"mixed_precision_acc_type": acc_dtype,
4949
},
50+
params,
5051
)
5152

5253
if executor_type == "ge":

0 commit comments

Comments
 (0)