Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/relay/transforms/to_mixed_precision.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ class MixedPrecisionPass : public MixedModeMutator {
If ignore_non_float, then ignore non-floating types.
*/
if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
bool is_supported_floating_point_type =
(tensor_type->dtype).is_float() || (tensor_type->dtype).is_bfloat16();
return (ignore_non_float && !is_supported_floating_point_type) ||
tensor_type->dtype == mixed_precision_type_;
} else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
for (Type t : tuple_type->fields) {
Expand Down
29 changes: 29 additions & 0 deletions tests/python/relay/test_to_mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,5 +459,34 @@ def test_batch_matmul_simple():
assert tvm.ir.structural_equal(expected_mod, output_mod)


def test_convert_follow_node_with_integer_arguments():
    """Check conversion of a follow op that mixes integer and float arguments.

    The op under test is ``take``: its data argument is a float32 var and its
    index argument is an int32 expression. The follow op should convert the
    floating point argument into fp16, as constants/vars will always be
    converted if safe to do so, while the integer indices are left untouched.
    """
    float_input = relay.var("data", shape=[1, 10], dtype="float32")

    # Adding a constant zero makes the indices fed to `take` the result of an
    # op rather than a bare var (vars are always cast when it is safe to do so).
    index_var = relay.var("indices", shape=[1, 1], dtype="int32")
    indices = index_var + relay.const(0, dtype="int32")

    original_mod = tvm.IRModule.from_expr(relay.take(float_input, indices, axis=0))

    # Concrete inputs used for the numerical closeness check.
    mod_params = {
        "data": np.random.uniform(-1, 1, size=[1, 10]).astype("float32"),
        "indices": np.array([[0]]).astype("int32"),
    }
    output_mod = verify_mixed_precision_output_close(
        original_mod, mod_params, atol=0.01, rtol=0.01
    )

    # Expected graph: the data var is cast to float16, while the very same
    # integer index expression is reused unchanged.
    fp16_input = relay.cast(relay.var("data", shape=[1, 10]), "float16")
    expected_take = relay.take(fp16_input, indices, axis=0)
    expected_mod = InferType()(tvm.IRModule.from_expr(expected_take))

    assert tvm.ir.structural_equal(expected_mod, output_mod)


# Allow running this test file directly as a script; pytest collects and
# executes the tests defined in this file only.
if __name__ == "__main__":
    pytest.main([__file__])