diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index d8c7aa2ffcfa..4ad3482f7464 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -194,7 +194,9 @@ class MixedPrecisionPass : public MixedModeMutator { If ignore_non_float, then ignore non-floating types. */ if (const TensorTypeNode* tensor_type = t.as()) { - return (!ignore_non_float || (tensor_type->dtype).is_float()) && + bool is_supported_floating_point_type = + (tensor_type->dtype).is_float() || (tensor_type->dtype).is_bfloat16(); + return (ignore_non_float && !is_supported_floating_point_type) || tensor_type->dtype == mixed_precision_type_; } else if (const TupleTypeNode* tuple_type = t.as()) { for (Type t : tuple_type->fields) { diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index 472f98715ec5..2afd6ff247ab 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -459,5 +459,34 @@ def test_batch_matmul_simple(): assert tvm.ir.structural_equal(expected_mod, output_mod) +def test_convert_follow_node_with_integer_arguments(): + """Tests the conversion of a follow op with integer arguments + constant float args. + + The follow op should convert the floating point argument into fp16 as constants/vars + will always be converted if safe to do so. + """ + + data = relay.var("data", shape=[1, 10], dtype="float32") + + # We use an addition to make sure the input indices are not a var + # (which are always casted if safe) + indices = relay.var("indices", shape=[1, 1], dtype="int32") + relay.const(0, dtype="int32") + take = relay.take(data, indices, axis=0) + mod = tvm.IRModule.from_expr(take) + + mod_params = { + "data": np.random.uniform(-1, 1, size=[1, 10]).astype("float32"), + "indices": np.array([[0]]).astype("int32"), + } + output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01) + + # Create expected module + data = relay.cast(relay.var("data", shape=[1, 10]), "float16") + take = relay.take(data, indices, axis=0) + expected_mod = tvm.IRModule.from_expr(take) + expected_mod = InferType()(expected_mod) + assert tvm.ir.structural_equal(expected_mod, output_mod) + + if __name__ == "__main__": pytest.main([__file__])