Commit eae836c

Fix mixed precision output type to original type (#11142)
1 parent 5007033 commit eae836c

File tree (2 files changed: +82, -17 lines)

src/relay/transforms/to_mixed_precision.cc
tests/python/relay/test_to_mixed_precision.py

src/relay/transforms/to_mixed_precision.cc

Lines changed: 53 additions & 7 deletions
@@ -36,6 +36,7 @@
 namespace tvm {
 namespace relay {
 
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.ToMixedPrecision.keep_orig_output_dtype", Bool);
 // A callable which hashes std::pair
 struct pair_hash {
   template <class T1, class T2>
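
The option registered above is read from the ambient PassContext, so callers opt in from Python without changing the pass signature. A minimal sketch of how the flag is set (mirroring the test change below; `mod` stands for an existing float32 Relay IRModule and is assumed here):

import tvm
from tvm.relay.transform import InferType, ToMixedPrecision

# mod: an existing float32 Relay IRModule (assumed).
mod = InferType()(mod)
with tvm.transform.PassContext(
    config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
):
    fp16_mod = ToMixedPrecision("float16")(mod)
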
@@ -105,6 +106,9 @@ class MixedPrecisionPass : public MixedModeMutator {
    * encountered. Used for emitting warnings on missing ops in the pass.
    */
   std::unordered_map<std::string, int> missing_ops_;
+  const RelayExprNode* root_;
+  std::vector<DataType> original_dtype_;
+  bool keep_orig_output_dtype_;
 
   Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
     /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
@@ -278,8 +282,23 @@ class MixedPrecisionPass : public MixedModeMutator {
  public:
   using MixedModeMutator::VisitExpr_;
 
-  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16))
-      : MixedModeMutator(), mixed_precision_type_(mixed_precision_type) {
+  explicit MixedPrecisionPass(Expr base, bool keep_orig_output_dtype,
+                              DataType mixed_precision_type = DataType::Float(16))
+      : MixedModeMutator(),
+        mixed_precision_type_(mixed_precision_type),
+        root_(Downcast<Function>(base)->body.get()),
+        keep_orig_output_dtype_(keep_orig_output_dtype) {
+    if (keep_orig_output_dtype_) {
+      if (root_->IsInstance<tvm::relay::TupleNode>()) {
+        const TupleTypeNode* tuple_type = (root_->checked_type_).as<TupleTypeNode>();
+        for (Type t : tuple_type->fields) {
+          const TensorTypeNode* tensor_type = t.as<TensorTypeNode>();
+          original_dtype_.push_back(tensor_type->dtype);
+        }
+      } else if (root_->IsInstance<tvm::relay::CallNode>()) {
+        original_dtype_.push_back((root_->checked_type_).as<TensorTypeNode>()->dtype);
+      }
+    }
     if (!mixed_precision_type_.is_float() && !mixed_precision_type_.is_bfloat16()) {
       LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got "
                  << mixed_precision_type_;
@@ -381,6 +400,11 @@ class MixedPrecisionPass : public MixedModeMutator {
     if (accumulation_dtype != output_dtype) {
       output = CastArg(output, GetType(output), output_dtype);
     }
+    if (pre_call_node == root_ && keep_orig_output_dtype_) {
+      if (original_dtype_[0] != output_dtype) {
+        output = CastArg(output, GetType(output), original_dtype_[0]);
+      }
+    }
     return output;
   }
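
When the function body is a single call, the hunk above casts the root call's output back to the recorded original dtype. A self-contained sketch of the observable effect (the conv2d graph here is illustrative, not taken from the commit):

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
out = relay.nn.conv2d(data, weight, padding=(1, 1))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
mod = relay.transform.InferType()(mod)

with tvm.transform.PassContext(
    config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
):
    fp16_mod = relay.transform.ToMixedPrecision("float16")(mod)

# Compute happens in float16 internally, but the final output is cast
# back, so the function's return type stays float32.
fp16_mod = relay.transform.InferType()(fp16_mod)
print(fp16_mod["main"].ret_type)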

@@ -396,6 +420,21 @@ class MixedPrecisionPass : public MixedModeMutator {
   Expr Rewrite_(const TupleNode* pre, const Expr& post) {
     // The old checked type in the expression may not be valid so clear it
     post->checked_type_ = Type(nullptr);
+    if (pre == root_ && keep_orig_output_dtype_) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < original_dtype_.size(); i++) {
+        Expr output_element = GetField(post, i);
+        Expr casted_element;
+        auto output_element_type = transform::InferTypeLocal(output_element);
+        casted_element = CastArg(output_element, output_element_type, original_dtype_[i]);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(output_element);
+      }
+      if (!all_same) {
+        return Tuple(new_expr);
+      }
+    }
     return post;
   }
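
For tuple-valued outputs, each field is cast back to its recorded dtype, and a new Tuple is built only if some field actually changed. A sketch of the multi-output case (again an illustrative graph, not from the commit):

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
w1 = relay.var("w1", shape=(8, 3, 3, 3), dtype="float32")
w2 = relay.var("w2", shape=(8, 3, 3, 3), dtype="float32")
body = relay.Tuple([
    relay.nn.conv2d(data, w1, padding=(1, 1)),
    relay.nn.conv2d(data, w2, padding=(1, 1)),
])
mod = tvm.IRModule.from_expr(relay.Function([data, w1, w2], body))
mod = relay.transform.InferType()(mod)

with tvm.transform.PassContext(
    config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
):
    fp16_mod = relay.transform.ToMixedPrecision("float16")(mod)

# Both tuple fields are cast back, so the return type should be a
# TupleType of two float32 tensors.
fp16_mod = relay.transform.InferType()(fp16_mod)
print(fp16_mod["main"].ret_type)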

@@ -421,11 +460,12 @@ class MixedPrecisionPass : public MixedModeMutator {
   }
 
   // To access map of ops not registered for error reporting
-  friend Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type,
-                               int missing_op_mode);
+  friend Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
+                               const DataType& mixed_precision_type, int missing_op_mode);
 };
 
-Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, int missing_op_mode) {
+Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
+                      const DataType& mixed_precision_type, int missing_op_mode) {
   /*
   missing_op_mode:
@@ -436,7 +476,8 @@ Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, in
   ICHECK(missing_op_mode >= 0 && missing_op_mode <= 2)
       << " missing_op_mode must be either 0, 1, or 2 got " << missing_op_mode;
 
-  MixedPrecisionPass converter = MixedPrecisionPass(mixed_precision_type);
+  MixedPrecisionPass converter =
+      MixedPrecisionPass(expr, keep_orig_output_dtype, mixed_precision_type);
   auto result = converter.Mutate(expr);
 
   for (auto it = converter.missing_ops_.begin();
@@ -460,7 +501,12 @@ namespace transform {
 Pass ToMixedPrecision(DataType mixed_precision_type, int missing_op_mode) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(ToMixedPrecision(f, mixed_precision_type, missing_op_mode));
+        bool keep_orig_output_dtype = false;
+        keep_orig_output_dtype = pc->GetConfig("relay.ToMixedPrecision.keep_orig_output_dtype",
+                                               Bool(keep_orig_output_dtype))
+                                     .value();
+        return Downcast<Function>(
+            ToMixedPrecision(f, keep_orig_output_dtype, mixed_precision_type, missing_op_mode));
       };
   return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {});
 }
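
When the option is absent from the active PassContext, GetConfig falls back to the Bool(false) default above, so existing callers are unaffected. Continuing the sketch from the first hunk (same assumed `mod`):

# Outside any PassContext that sets the option, the default (False)
# applies and converted outputs stay in the mixed precision dtype.
fp16_mod_default = relay.transform.ToMixedPrecision("float16")(mod)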

tests/python/relay/test_to_mixed_precision.py

Lines changed: 29 additions & 10 deletions
@@ -41,17 +41,31 @@ def verify_mixed_precision_output_close(
     mixed_precision_dtype="float16",
     rtol: float = 1e-3,
     atol: float = 0,
+    keep_orig_output_dtype=False,
 ) -> tvm.runtime.Module:
 
     mod = InferType()(mod)
     result_fp32 = run_module(mod, mod_params)
-    fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
-    result_fp16 = run_module(fp16_mod, mod_params)
+
+    if not keep_orig_output_dtype:
+        fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
+        result_fp16 = run_module(fp16_mod, mod_params)
+    else:
+        with tvm.transform.PassContext(
+            config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
+        ):
+            fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
+            result_fp16 = run_module(fp16_mod, mod_params)
 
     # Ensure the results are close
     for fp32, fp16 in zip(result_fp32, result_fp16):
         np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)
 
+    if keep_orig_output_dtype:
+        assert (
+            np.array(result_fp16).dtype == np.array(result_fp32).dtype
+        ), "output type and original type mismatch"
+
     return fp16_mod
 
@@ -117,16 +131,21 @@ def test_convert_single_conv():
         "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
         "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
     }
-    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+    fp16_mod = verify_mixed_precision_output_close(
+        mod, mod_params, atol=0.01, rtol=1e-3, keep_orig_output_dtype=True
+    )
 
     expected_mod = tvm.IRModule.from_expr(
-        relay.nn.conv2d(
-            relay.cast(data, "float16"),
-            relay.cast(weight, "float16"),
-            strides=(1, 1),
-            padding=(1, 1),
-            out_dtype="float16",
-        ),
+        relay.cast(
+            relay.nn.conv2d(
+                relay.cast(data, "float16"),
+                relay.cast(weight, "float16"),
+                strides=(1, 1),
+                padding=(1, 1),
+                out_dtype="float16",
+            ),
+            "float32",
+        )
     )
     expected_mod = tvm.relay.transform.InferType()(expected_mod)
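
The hunk ends at the InferType call; presumably the test closes by comparing the converted module against expected_mod with TVM's structural equality. A hypothetical final assertion (not part of this diff):

# Hypothetical closing check, not shown in the hunk above:
assert tvm.ir.structural_equal(fp16_mod, expected_mod)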
