Skip to content

Commit 9258997

Browse files
[AMP] Fix IsMixedPrecisionType Edge Case (#9856)
* fix ismixedprecisiontype * fix test * jostle * Update tests/python/relay/test_to_mixed_precision.py Co-authored-by: Cody Yu <[email protected]> Co-authored-by: Cody Yu <[email protected]>
1 parent 2493aeb commit 9258997

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

src/relay/transforms/to_mixed_precision.cc

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -194,7 +194,9 @@ class MixedPrecisionPass : public MixedModeMutator {
194194
If ignore_non_float, then ignore non-floating types.
195195
*/
196196
if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
197-
return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
197+
bool is_supported_floating_point_type =
198+
(tensor_type->dtype).is_float() || (tensor_type->dtype).is_bfloat16();
199+
return (ignore_non_float && !is_supported_floating_point_type) ||
198200
tensor_type->dtype == mixed_precision_type_;
199201
} else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
200202
for (Type t : tuple_type->fields) {

tests/python/relay/test_to_mixed_precision.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -459,5 +459,34 @@ def test_batch_matmul_simple():
459459
assert tvm.ir.structural_equal(expected_mod, output_mod)
460460

461461

462+
def test_convert_follow_node_with_integer_arguments():
    """Tests the conversion of a follow op with integer arguments + constant float args.

    The follow op should convert the floating point argument into fp16 as constants/vars
    will always be converted if safe to do so.
    """
    fp_input = relay.var("data", shape=[1, 10], dtype="float32")

    # Route the indices through an addition so the take op's index argument
    # is not a plain var (plain vars are always cast when it is safe to do so).
    int_indices = relay.var("indices", shape=[1, 1], dtype="int32") + relay.const(0, dtype="int32")
    mod = tvm.IRModule.from_expr(relay.take(fp_input, int_indices, axis=0))

    output_mod = verify_mixed_precision_output_close(
        mod,
        {
            "data": np.random.uniform(-1, 1, size=[1, 10]).astype("float32"),
            "indices": np.array([[0]]).astype("int32"),
        },
        atol=0.01,
        rtol=0.01,
    )

    # Expected module: the floating point input is cast to fp16 while the
    # integer indices expression is left untouched.
    casted_input = relay.cast(relay.var("data", shape=[1, 10]), "float16")
    expected_mod = tvm.IRModule.from_expr(relay.take(casted_input, int_indices, axis=0))
    expected_mod = InferType()(expected_mod)
    assert tvm.ir.structural_equal(expected_mod, output_mod)
462491
if __name__ == "__main__":
463492
pytest.main([__file__])

0 commit comments

Comments
 (0)