@@ -303,11 +303,7 @@ class ToMixedPrecisionRewriter : public ExprMutator {
303303 }
304304
// RemapArgs (shown here as a diff): the removed lines build `new_args` by
// pushing VarReplacer::Replace(arg, var_remap_) for every element of `args`
// and return it; the added line expresses the same per-element replacement
// through `args.Map`, using a lambda that captures `this` and passes
// `var_remap_` to VarReplacer::Replace.
305305 Array<Expr> RemapArgs (const Array<Expr>& args) {
306- Array<Expr> new_args;
307- for (const auto & arg : args) {
308- new_args.push_back (VarReplacer::Replace (arg, var_remap_));
309- }
310- return new_args;
306+ return args.Map ([this ](Expr arg) { return VarReplacer::Replace (arg, var_remap_); });
311307 }
312308
313309 // Util function to rewrite the expr to the given dtype
@@ -475,37 +471,60 @@ class ToMixedPrecisionRewriter : public ExprMutator {
475471 ReEmitBinding (binding, call_node->args [0 ]);
476472 return ;
477473 }
478- DataType to;
479- ObjectPtr<CallNode> new_call = make_object<CallNode>(*call_node);
474+
475+ Call new_call = GetRef<Call>(call_node);
476+
480477 // We first remap the args to the current vars according to the var_remap_
481- new_call->args = std::move (RemapArgs (call_node->args ));
478+ new_call.CopyOnWrite ()->args = RemapArgs (new_call->args );
479+
482480 // Then we rewrite the args according to the policy
481+ std::optional<DataType> opt_new_dtype = std::nullopt ;
482+
483483 if (policy == kAlways ) {
484- to = fp16_;
484+ opt_new_dtype = fp16_;
485485 auto attr_map = Op::GetAttrMap<FInferMixedPrecision>(" FInferMixedPrecision" );
486486 ICHECK (attr_map.count (op));
487- auto f = attr_map[op];
488- new_call = make_object<CallNode>(*(f (Call (new_call), output_dtype_).get ()));
487+ new_call = attr_map[op](new_call, output_dtype_);
489488 } else if (policy == kFollow ) {
490- to = AllFP16Castable (new_call->args ) ? fp16_ : fp32_;
489+ opt_new_dtype = AllFP16Castable (new_call->args ) ? fp16_ : fp32_;
491490 } else if (policy == kNever ) {
492- to = fp32_;
491+ // An upstream operation may have changed the datatype of the
492+ // arguments. Because this operation must be provided with
493+ // exactly the same dtype as it previously had, it may require a
494+ // cast back to the original datatype.
495+
496+ if (!new_call->args .same_as (call_node->args )) {
497+ Array<Expr> new_typed_args;
498+ for (size_t i = 0 ; i < call_node->args .size (); i++) {
499+ auto arg = new_call->args [i];
500+ auto old_ntype = NTypeFrom (call_node->args [i]);
501+ new_typed_args.push_back (RewriteExpr (arg, old_ntype));
502+ }
503+ new_call.CopyOnWrite ()->args = new_typed_args;
504+ }
505+
493506 } else {
494507 LOG (FATAL) << " Unsupported TMixedPrecisionPolicy: " << policy;
495508 }
496- new_call->args = std::move (RewriteArgs (new_call->args , to));
497- new_call->struct_info_ = NullOpt;
498- Expr new_value = builder_->Normalize (Call (new_call));
499- if (policy == kAlways && binding->var ->IsInstance <DataflowVarNode>()) {
500- // kAlways: store the tensors to fp16
501- // But global vars will be stored to the original dtype anyway (see below)
502- new_value = RewriteExpr (new_value, NTypeFrom (new_value, fp16_));
503- }
504- if (!binding->var ->IsInstance <DataflowVarNode>()) {
505- // Global var: store the tensors to the original dtype
506- NType to = NTypeFrom (binding->var );
507- new_value = RewriteExpr (new_value, to);
509+
510+ Expr new_value = new_call;
511+ if (opt_new_dtype) {
512+ auto new_dtype = opt_new_dtype.value ();
513+ new_call.CopyOnWrite ()->args = RewriteArgs (new_call->args , new_dtype);
514+ new_call.CopyOnWrite ()->struct_info_ = NullOpt;
515+
516+ new_value = builder_->Normalize (Call (new_call));
517+
518+ if (!binding->var ->IsInstance <DataflowVarNode>()) {
519+ // Non-Dataflow var: store the tensors to the original dtype
520+ new_value = RewriteExpr (new_value, NTypeFrom (binding->var ));
521+ } else if (policy == kAlways && binding->var ->IsInstance <DataflowVarNode>()) {
522+ // kAlways: store the tensors to fp16
523+ // But non-dataflow vars will be stored to the original dtype anyway (see above)
524+ new_value = RewriteExpr (new_value, NTypeFrom (new_value, new_dtype));
525+ }
508526 }
527+
509528 ReEmitBinding (binding, builder_->Normalize (new_value));
510529 }
511530
0 commit comments