
Commit b3d01c2

[Relax][Bugfix] Preserve dtype in ToMixedPrecision for kNever ops (#17263)
Prior to this commit, while an operator with the `MixedPrecisionPolicyKind::kNever` attribute would not be updated from `float32` to `float16`, it would be erroneously updated from `float16` to `float32`. This commit updates `ToMixedPrecision` to preserve the datatype of any arguments used in a `kNever` operation, rather than forcing them to a `float32` datatype.
1 parent bed66d2 commit b3d01c2
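
For illustration, here is a minimal, self-contained sketch of the fixed behavior. It is not part of the commit; it assumes a TVM build with Relax available and mirrors the regression test added below, in which `R.call_tir` exercises the kNever path described in the commit message.

import tvm
from tvm.relax.transform import ToMixedPrecision
from tvm.script.parser import ir as I, relax as R, tir as T


@I.ir_module
class Module:
    @R.function
    def main(A: R.Tensor([64], "float16")):
        cls = Module
        with R.dataflow():
            B = R.call_tir(cls.tir_identity, [A], out_sinfo=R.Tensor([64], "float16"))
            R.output(B)
        return B

    @T.prim_func
    def tir_identity(
        Input: T.Buffer(64, "float16"),
        Output: T.Buffer(64, "float16"),
    ):
        for i in range(64):
            with T.block("copy"):
                vi = T.axis.remap("S", [i])
                Output[vi] = Input[vi]


# Before this commit, the float16 arguments of the call_tir (a kNever op) would
# be erroneously rewritten to float32; with the fix the module passes through
# ToMixedPrecision unchanged.
After = ToMixedPrecision()(Module)
tvm.ir.assert_structural_equal(Module, After)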

File tree: 2 files changed, +75 −28 lines


src/relax/transform/to_mixed_precision.cc

Lines changed: 44 additions & 25 deletions
@@ -303,11 +303,7 @@ class ToMixedPrecisionRewriter : public ExprMutator {
   }
 
   Array<Expr> RemapArgs(const Array<Expr>& args) {
-    Array<Expr> new_args;
-    for (const auto& arg : args) {
-      new_args.push_back(VarReplacer::Replace(arg, var_remap_));
-    }
-    return new_args;
+    return args.Map([this](Expr arg) { return VarReplacer::Replace(arg, var_remap_); });
   }
 
   // Util function to rewrite the expr to the given dtype
@@ -475,37 +471,60 @@ class ToMixedPrecisionRewriter : public ExprMutator {
       ReEmitBinding(binding, call_node->args[0]);
       return;
     }
-    DataType to;
-    ObjectPtr<CallNode> new_call = make_object<CallNode>(*call_node);
+
+    Call new_call = GetRef<Call>(call_node);
+
     // We first to remap the args to the current vars according to the var_remap_
-    new_call->args = std::move(RemapArgs(call_node->args));
+    new_call.CopyOnWrite()->args = RemapArgs(new_call->args);
+
     // Then we rewrite the args according to the policy
+    std::optional<DataType> opt_new_dtype = std::nullopt;
+
     if (policy == kAlways) {
-      to = fp16_;
+      opt_new_dtype = fp16_;
       auto attr_map = Op::GetAttrMap<FInferMixedPrecision>("FInferMixedPrecision");
       ICHECK(attr_map.count(op));
-      auto f = attr_map[op];
-      new_call = make_object<CallNode>(*(f(Call(new_call), output_dtype_).get()));
+      new_call = attr_map[op](new_call, output_dtype_);
     } else if (policy == kFollow) {
-      to = AllFP16Castable(new_call->args) ? fp16_ : fp32_;
+      opt_new_dtype = AllFP16Castable(new_call->args) ? fp16_ : fp32_;
    } else if (policy == kNever) {
-      to = fp32_;
+      // An upstream operation may have changed the datatype of the
+      // arguments. Because this operation must be provided with
+      // exactly the same dtype as it previously had, it may require a
+      // cast back to the original datatype.
+
+      if (!new_call->args.same_as(call_node->args)) {
+        Array<Expr> new_typed_args;
+        for (size_t i = 0; i < call_node->args.size(); i++) {
+          auto arg = new_call->args[i];
+          auto old_ntype = NTypeFrom(call_node->args[i]);
+          new_typed_args.push_back(RewriteExpr(arg, old_ntype));
+        }
+        new_call.CopyOnWrite()->args = new_typed_args;
+      }
+
     } else {
       LOG(FATAL) << "Unsupported TMixedPrecisionPolicy: " << policy;
     }
-    new_call->args = std::move(RewriteArgs(new_call->args, to));
-    new_call->struct_info_ = NullOpt;
-    Expr new_value = builder_->Normalize(Call(new_call));
-    if (policy == kAlways && binding->var->IsInstance<DataflowVarNode>()) {
-      // kAlways: store the tensors to fp16
-      // But global vars will be stored to the original dtype anyway (see below)
-      new_value = RewriteExpr(new_value, NTypeFrom(new_value, fp16_));
-    }
-    if (!binding->var->IsInstance<DataflowVarNode>()) {
-      // Global var: store the tensors to the original dtype
-      NType to = NTypeFrom(binding->var);
-      new_value = RewriteExpr(new_value, to);
+
+    Expr new_value = new_call;
+    if (opt_new_dtype) {
+      auto new_dtype = opt_new_dtype.value();
+      new_call.CopyOnWrite()->args = RewriteArgs(new_call->args, new_dtype);
+      new_call.CopyOnWrite()->struct_info_ = NullOpt;
+
+      new_value = builder_->Normalize(Call(new_call));
+
+      if (!binding->var->IsInstance<DataflowVarNode>()) {
+        // Non-Dataflow var: store the tensors to the original dtype
+        new_value = RewriteExpr(new_value, NTypeFrom(binding->var));
+      } else if (policy == kAlways && binding->var->IsInstance<DataflowVarNode>()) {
+        // kAlways: store the tensors to fp16
+        // But non-dataflow vars will be stored to the original dtype anyway (see above)
+        new_value = RewriteExpr(new_value, NTypeFrom(new_value, new_dtype));
+      }
     }
+
     ReEmitBinding(binding, builder_->Normalize(new_value));
   }
 

tests/python/relax/test_transform_to_mixed_precision.py

Lines changed: 31 additions & 3 deletions
@@ -20,7 +20,7 @@
 from tvm import relax
 import tvm.testing
 from tvm.relax.transform import ToMixedPrecision
-from tvm.script.parser import ir as I, relax as R
+from tvm.script.parser import ir as I, relax as R, tir as T
 
 
 def _assert_test(input, expected=None, expected2=None):
@@ -614,8 +614,8 @@ def main(
             x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 3, 3, 3), "float32")
         ) -> R.Tensor(None, "float32", ndim=4):
             with R.dataflow():
-                gv: R.Tensor((2, 3, 26, 26), "float32") = R.nn.conv2d(x, w, padding=(1, 1))
-                gv1: R.Tensor((2, 3, 26, 26), "float32") = R.nn.softmax(x, axis=1)
+                gv: R.Tensor((2, 3, 28, 28), "float32") = R.nn.conv2d(x, w, padding=(1, 1))
+                gv1: R.Tensor((2, 3, 28, 28), "float32") = R.nn.softmax(x, axis=1)
                 gv2 = R.add(gv, gv1)
                 R.output(gv2)
             return gv2
@@ -1036,5 +1036,33 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_call_tir_with_float16_args():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor([64], "float16")):
+            cls = Before
+            with R.dataflow():
+                B = R.call_tir(cls.tir_identity, [A], out_sinfo=R.Tensor([64], "float16"))
+                C = R.call_tir(cls.tir_identity, [B], out_sinfo=R.Tensor([64], "float16"))
+                R.output(C)
+            return C
+
+        @T.prim_func
+        def tir_identity(
+            Input: T.Buffer(64, "float16"),
+            Output: T.Buffer(64, "float16"),
+        ):
+            for i in range(64):
+                with T.block("copy"):
+                    vi = T.axis.remap("S", [i])
+                    Output[vi] = Input[vi]
+
+    Expected = Before
+
+    After = ToMixedPrecision()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
