Skip to content

Commit

Permalink
Allow convert_to_tensor to take a value with the wrong dtype on T…
Browse files Browse the repository at this point in the history
…ensorflow. (#20513)

`ops.convert_to_tensor(1.0, "int32")` would fail with the TensorFlow backend. This case is now supported.

Note that other backends already supported this.
  • Loading branch information
hertschuh authored Nov 19, 2024
1 parent f17a2c2 commit b5ddf32
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
8 changes: 5 additions & 3 deletions keras/src/backend/tensorflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from keras.src import tree
from keras.src.backend.common import KerasVariable
from keras.src.backend.common import global_state
from keras.src.backend.common import is_int_dtype
from keras.src.backend.common import standardize_dtype
from keras.src.backend.common.backend_utils import slice_along_axis
from keras.src.backend.common.keras_tensor import KerasTensor
Expand Down Expand Up @@ -119,9 +120,10 @@ def convert_to_tensor(x, dtype=None, sparse=None):
if dtype is not None:
dtype = standardize_dtype(dtype)
if not tf.is_tensor(x):
if dtype == "bool":
# TensorFlow boolean conversion is stricter than other backends.
# It does not allow ints. We convert without dtype and cast instead.
if dtype == "bool" or is_int_dtype(dtype):
# TensorFlow conversion is stricter than other backends, it does not
# allow ints for bools or floats for ints. We convert without dtype
# and cast instead.
x = tf.convert_to_tensor(x)
return tf.cast(x, dtype)
return tf.convert_to_tensor(x, dtype=dtype)
Expand Down
3 changes: 3 additions & 0 deletions keras/src/ops/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,9 @@ class CoreOpsDtypeTest(testing.TestCase):
(bool(0), None, "bool"),
(int(0), None, "int32"),
(float(0), None, backend.floatx()),
(1, "bool", "bool"),
(1.0, "int32", "int32"),
(1.0, "float32", "float32"),
([False, True, False], None, "bool"),
([1, 2, 3], None, "int32"),
([1.0, 2.0, 3.0], None, backend.floatx()),
Expand Down

0 comments on commit b5ddf32

Please sign in to comment.