3636namespace tvm {
3737namespace relay {
3838
39+ TVM_REGISTER_PASS_CONFIG_OPTION (" relay.ToMixedPrecision.keep_orig_output_dtype" , Bool);
3940// A callable which hashes std::pair
4041struct pair_hash {
4142 template <class T1 , class T2 >
@@ -105,6 +106,9 @@ class MixedPrecisionPass : public MixedModeMutator {
105106 * encountered. Used for emitting warnings on missing ops in the pass.
106107 */
107108 std::unordered_map<std::string, int > missing_ops_;
109+ const RelayExprNode* root_;
110+ std::vector<DataType> original_dtype_;
111+ bool keep_orig_output_dtype_;
108112
109113 Attrs GetNewAttrs (const CallNode* call, const DataType& accumulation_dtype) const {
110114 /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
@@ -278,8 +282,23 @@ class MixedPrecisionPass : public MixedModeMutator {
278282 public:
279283 using MixedModeMutator::VisitExpr_;
280284
281- explicit MixedPrecisionPass (DataType mixed_precision_type = DataType::Float(16 ))
282- : MixedModeMutator(), mixed_precision_type_(mixed_precision_type) {
285+ explicit MixedPrecisionPass (Expr base, bool keep_orig_output_dtype,
286+ DataType mixed_precision_type = DataType::Float(16 ))
287+ : MixedModeMutator(),
288+ mixed_precision_type_(mixed_precision_type),
289+ root_(Downcast<Function>(base)->body.get()),
290+ keep_orig_output_dtype_(keep_orig_output_dtype) {
291+ if (keep_orig_output_dtype_) {
292+ if (root_->IsInstance <tvm::relay::TupleNode>()) {
293+ const TupleTypeNode* tuple_type = (root_->checked_type_ ).as <TupleTypeNode>();
294+ for (Type t : tuple_type->fields ) {
295+ const TensorTypeNode* tensor_type = t.as <TensorTypeNode>();
296+ original_dtype_.push_back (tensor_type->dtype );
297+ }
298+ } else if (root_->IsInstance <tvm::relay::CallNode>()) {
299+ original_dtype_.push_back ((root_->checked_type_ ).as <TensorTypeNode>()->dtype );
300+ }
301+ }
283302 if (!mixed_precision_type_.is_float () && !mixed_precision_type_.is_bfloat16 ()) {
284303 LOG (FATAL) << " Only support IEEE floating point mixed precision types and bfloat16, but got "
285304 << mixed_precision_type_;
@@ -381,6 +400,11 @@ class MixedPrecisionPass : public MixedModeMutator {
381400 if (accumulation_dtype != output_dtype) {
382401 output = CastArg (output, GetType (output), output_dtype);
383402 }
403+ if (pre_call_node == root_ && keep_orig_output_dtype_) {
404+ if (original_dtype_[0 ] != output_dtype) {
405+ output = CastArg (output, GetType (output), original_dtype_[0 ]);
406+ }
407+ }
384408 return output;
385409 }
386410
@@ -396,6 +420,21 @@ class MixedPrecisionPass : public MixedModeMutator {
396420 Expr Rewrite_ (const TupleNode* pre , const Expr& post ) {
397421 // The old checked type in the expression may not be valid so clear it
398422 post ->checked_type_ = Type (nullptr );
423+ if (pre == root_ && keep_orig_output_dtype_) {
424+ Array<Expr> new_expr;
425+ bool all_same = true ;
426+ for (size_t i = 0 ; i < original_dtype_.size (); i++) {
427+ Expr output_element = GetField (post , i);
428+ Expr casted_element;
429+ auto output_element_type = transform::InferTypeLocal (output_element);
430+ casted_element = CastArg (output_element, output_element_type, original_dtype_[i]);
431+ new_expr.push_back (casted_element);
432+ all_same &= casted_element.same_as (output_element);
433+ }
434+ if (!all_same) {
435+ return Tuple (new_expr);
436+ }
437+ }
399438 return post ;
400439 }
401440
@@ -421,11 +460,12 @@ class MixedPrecisionPass : public MixedModeMutator {
421460 }
422461
423462 // To access map of ops not registered for error reporting
424- friend Expr ToMixedPrecision (const Expr& expr, const DataType& mixed_precision_type ,
425- int missing_op_mode);
463+ friend Expr ToMixedPrecision (const Expr& expr, bool keep_orig_output_dtype ,
464+ const DataType& mixed_precision_type, int missing_op_mode);
426465};
427466
428- Expr ToMixedPrecision (const Expr& expr, const DataType& mixed_precision_type, int missing_op_mode) {
467+ Expr ToMixedPrecision (const Expr& expr, bool keep_orig_output_dtype,
468+ const DataType& mixed_precision_type, int missing_op_mode) {
429469 /*
430470 missing_op_mode:
431471
@@ -436,7 +476,8 @@ Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, in
436476 ICHECK (missing_op_mode >= 0 && missing_op_mode <= 2 )
437477 << " missing_op_mode must be either 0, 1, or 2 got " << missing_op_mode;
438478
439- MixedPrecisionPass converter = MixedPrecisionPass (mixed_precision_type);
479+ MixedPrecisionPass converter =
480+ MixedPrecisionPass (expr, keep_orig_output_dtype, mixed_precision_type);
440481 auto result = converter.Mutate (expr);
441482
442483 for (auto it = converter.missing_ops_ .begin ();
@@ -460,7 +501,12 @@ namespace transform {
460501Pass ToMixedPrecision (DataType mixed_precision_type, int missing_op_mode) {
461502 runtime::TypedPackedFunc<Function (Function, IRModule, PassContext)> pass_func =
462503 [=](Function f, IRModule m, PassContext pc) {
463- return Downcast<Function>(ToMixedPrecision (f, mixed_precision_type, missing_op_mode));
504+ bool keep_orig_output_dtype = false ;
505+ keep_orig_output_dtype = pc->GetConfig (" relay.ToMixedPrecision.keep_orig_output_dtype" ,
506+ Bool (keep_orig_output_dtype))
507+ .value ();
508+ return Downcast<Function>(
509+ ToMixedPrecision (f, keep_orig_output_dtype, mixed_precision_type, missing_op_mode));
464510 };
465511 return CreateFunctionPass (pass_func, 0 , " ToMixedPrecision" , {});
466512}
0 commit comments