Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions keras_nlp/backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
import json
import os

import keras_core

_MULTI_BACKEND = False
_USE_KERAS_3 = False

# Set Keras base dir path given KERAS_HOME env variable, if applicable.
# Otherwise either ~/.keras or /tmp.
Expand Down Expand Up @@ -59,16 +58,47 @@
# Except permission denied.
pass

# If KERAS_BACKEND is set to a non-empty value in the environment, use
# multi-backend keras. An unset or empty variable leaves the flag untouched.
if os.environ.get("KERAS_BACKEND"):
    _MULTI_BACKEND = True


def detect_if_tensorflow_uses_keras_3():
    """Report whether the Keras that TensorFlow is configured to use is 3.x.

    The check follows `tensorflow.keras` rather than any separately
    importable `keras` package.

    Returns:
        `True` when `tensorflow.keras` has a `version()` function whose
        result starts with `"3."`, and `False` otherwise.
    """
    from tensorflow import keras

    # Only recent versions of keras have a `version()` function, so a missing
    # attribute means we are on an old version of keras.
    if not hasattr(keras, "version"):
        return False

    return keras.version().startswith("3.")


# Running on Keras 3 always switches on the multi-backend code path,
# regardless of whether KERAS_BACKEND was set in the environment.
_USE_KERAS_3 = detect_if_tensorflow_uses_keras_3()
if _USE_KERAS_3:
    _MULTI_BACKEND = True


def keras_3():
    """Check if Keras 3 is being used.

    Returns:
        The module-level `_USE_KERAS_3` flag.
    """
    return _USE_KERAS_3


def multi_backend():
    """Check if multi-backend Keras is enabled.

    Returns:
        The module-level `_MULTI_BACKEND` flag.
    """
    return _MULTI_BACKEND


def backend():
    """Check the backend framework.

    Returns:
        `"tensorflow"` when multi-backend Keras is disabled. Otherwise the
        value of `config.backend()` from the Keras flavor in use: the Keras
        that ships with TensorFlow when `keras_3()` is true, and
        `keras_core` when it is not.
    """
    if not multi_backend():
        return "tensorflow"

    if keras_3():
        # Follow the keras that tensorflow is configured to use.
        from tensorflow import keras

        return keras.config.backend()

    # Imported lazily so that `keras_core` is only required when it is the
    # active multi-backend implementation.
    import keras_core

    return keras_core.config.backend()
6 changes: 4 additions & 2 deletions keras_nlp/backend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

import tensorflow as tf

from keras_nlp.backend.config import multi_backend
from keras_nlp.backend import config

if multi_backend():
if config.keras_3():
from keras import * # noqa: F403, F401
elif config.multi_backend():
from keras_core import * # noqa: F403, F401
else:
from tensorflow.keras import * # noqa: F403, F401
Expand Down
24 changes: 8 additions & 16 deletions keras_nlp/backend/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import keras_core
import tensorflow as tf
from keras_nlp.backend import config

from keras_nlp.backend.config import multi_backend

if multi_backend():
from keras_core.src.ops import * # noqa: F403, F401
if config.keras_3():
from keras.ops import * # noqa: F403, F401
else:
from keras_core.src.backend.tensorflow import * # noqa: F403, F401
from keras_core.src.backend.tensorflow.core import * # noqa: F403, F401
from keras_core.src.backend.tensorflow.math import * # noqa: F403, F401
from keras_core.src.backend.tensorflow.nn import * # noqa: F403, F401
from keras_core.src.backend.tensorflow.numpy import * # noqa: F403, F401

from keras_core.ops import * # noqa: F403, F401

if keras_core.config.backend() == "tensorflow" or not multi_backend():
if config.backend() == "tensorflow":
import tensorflow as tf
from tensorflow.experimental import numpy as tfnp

def take_along_axis(x, indices, axis=None):
# TODO: move this workaround for dynamic shapes into keras-core.
Expand All @@ -46,6 +40,4 @@ def take_along_axis(x, indices, axis=None):
indices = tf.squeeze(indices, leftover_axes)
return tf.gather(x, indices, batch_dims=axis)
# Otherwise, fall back to the tfnp call.
return keras_core.src.backend.tensorflow.numpy.take_along_axis(
x, indices, axis=axis
)
return tfnp.take_along_axis(x, indices, axis=axis)
8 changes: 4 additions & 4 deletions keras_nlp/backend/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.backend.config import multi_backend
from keras_nlp.backend import config

if multi_backend():
from keras_core.random import * # noqa: F403, F401
if config.keras_3():
from keras.random import * # noqa: F403, F401
else:
from keras_core.src.backend.tensorflow.random import * # noqa: F403, F401
from keras_core.random import * # noqa: F403, F401
7 changes: 4 additions & 3 deletions keras_nlp/layers/modeling/cached_multi_head_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from keras_nlp.backend import config
from keras_nlp.backend import ops
from keras_nlp.backend import random
from keras_nlp.layers.modeling.cached_multi_head_attention import (
CachedMultiHeadAttention,
)
Expand All @@ -29,8 +30,8 @@ def test_layer_behaviors(self):
"key_dim": 4,
},
input_data={
"query": ops.random.uniform(shape=(2, 4, 6)),
"value": ops.random.uniform(shape=(2, 4, 6)),
"query": random.uniform(shape=(2, 4, 6)),
"value": random.uniform(shape=(2, 4, 6)),
},
expected_output_shape=(2, 4, 6),
expected_num_trainable_weights=8,
Expand All @@ -48,7 +49,7 @@ def test_cache_call_is_correct(self):
hidden_dim = num_heads * key_dim

input_shape = (batch_size, seq_len, hidden_dim)
x = ops.random.uniform(shape=input_shape)
x = random.uniform(shape=input_shape)
input_cache = ops.zeros((batch_size, 2, seq_len, num_heads, key_dim))
# Use a causal mask.
mask = ops.tril(ops.ones((seq_len, seq_len)))
Expand Down
4 changes: 2 additions & 2 deletions keras_nlp/layers/modeling/f_net_encoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.backend import ops
from keras_nlp.backend import random
from keras_nlp.layers.modeling.f_net_encoder import FNetEncoder
from keras_nlp.tests.test_case import TestCase

Expand All @@ -29,7 +29,7 @@ def test_layer_behaviors(self):
"kernel_initializer": "HeNormal",
"bias_initializer": "Zeros",
},
input_data=ops.random.uniform(shape=(2, 4, 6)),
input_data=random.uniform(shape=(2, 4, 6)),
expected_output_shape=(2, 4, 6),
expected_num_trainable_weights=8,
expected_num_non_trainable_variables=1,
Expand Down
10 changes: 5 additions & 5 deletions keras_nlp/layers/modeling/masked_lm_head_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.backend import ops
from keras_nlp.backend import random
from keras_nlp.layers.modeling.masked_lm_head import MaskedLMHead
from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
from keras_nlp.tests.test_case import TestCase
Expand All @@ -29,8 +29,8 @@ def test_layer_behaviors(self):
"bias_initializer": "Zeros",
},
input_data={
"inputs": ops.random.uniform(shape=(4, 10, 16)),
"mask_positions": ops.random.randint(
"inputs": random.uniform(shape=(4, 10, 16)),
"mask_positions": random.randint(
minval=0, maxval=10, shape=(4, 5)
),
},
Expand All @@ -51,8 +51,8 @@ def test_layer_behaviors_with_embedding(self):
"token_embedding": embedding,
},
input_data={
"inputs": ops.random.uniform(shape=(4, 10, 16)),
"mask_positions": ops.random.randint(
"inputs": random.uniform(shape=(4, 10, 16)),
"mask_positions": random.randint(
minval=0, maxval=10, shape=(4, 5)
),
},
Expand Down
7 changes: 4 additions & 3 deletions keras_nlp/layers/modeling/position_embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.backend import random
from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
from keras_nlp.tests.test_case import TestCase

Expand All @@ -34,7 +35,7 @@ def test_layer_behaviors(self):
init_kwargs={
"sequence_length": 21,
},
input_data=ops.random.uniform(shape=(4, 21, 30)),
input_data=random.uniform(shape=(4, 21, 30)),
expected_output_shape=(4, 21, 30),
expected_num_trainable_weights=1,
)
Expand All @@ -45,7 +46,7 @@ def test_layer_behaviors_4d(self):
init_kwargs={
"sequence_length": 21,
},
input_data=ops.random.uniform(shape=(4, 5, 21, 30)),
input_data=random.uniform(shape=(4, 5, 21, 30)),
expected_output_shape=(4, 5, 21, 30),
expected_num_trainable_weights=1,
)
Expand Down Expand Up @@ -145,7 +146,7 @@ def test_callable_initializer(self):
def test_start_index(self):
batch_size, seq_length, feature_size = 2, 3, 4
layer = PositionEmbedding(seq_length)
data = ops.random.uniform(shape=(batch_size, seq_length, feature_size))
data = random.uniform(shape=(batch_size, seq_length, feature_size))
full_output = layer(data)
sequential_output = ops.zeros((batch_size, seq_length, feature_size))
for i in range(seq_length):
Expand Down
3 changes: 2 additions & 1 deletion keras_nlp/layers/modeling/reversible_embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.backend import random
from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
from keras_nlp.tests.test_case import TestCase

Expand All @@ -38,7 +39,7 @@ def test_layer_behaviors_tied(self, tie_weights):
"tie_weights": tie_weights,
"embeddings_initializer": "HeNormal",
},
input_data=ops.random.randint(minval=0, maxval=100, shape=(4, 10)),
input_data=random.randint(minval=0, maxval=100, shape=(4, 10)),
expected_output_shape=(4, 10, 32),
expected_num_trainable_weights=1 if tie_weights else 2,
)
Expand Down
7 changes: 4 additions & 3 deletions keras_nlp/layers/modeling/rotary_embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.backend import random
from keras_nlp.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_nlp.tests.test_case import TestCase

Expand All @@ -28,7 +29,7 @@ def test_layer_behaviors(self):
"sequence_axis": 1,
"feature_axis": -1,
},
input_data=ops.random.uniform(shape=(2, 4, 6)),
input_data=random.uniform(shape=(2, 4, 6)),
expected_output_shape=(2, 4, 6),
)

Expand All @@ -38,7 +39,7 @@ def test_layer_behaviors_4d(self):
init_kwargs={
"max_wavelength": 1000,
},
input_data=ops.random.uniform(shape=(2, 8, 4, 6)),
input_data=random.uniform(shape=(2, 8, 4, 6)),
expected_output_shape=(2, 8, 4, 6),
)

Expand Down Expand Up @@ -86,7 +87,7 @@ def test_output_correct_values(self):
def test_start_index(self):
batch_size, seq_length, feature_size = 2, 3, 4
layer = RotaryEmbedding(seq_length)
data = ops.random.uniform(shape=(batch_size, seq_length, feature_size))
data = random.uniform(shape=(batch_size, seq_length, feature_size))
full_output = layer(data)
sequential_output = ops.zeros((batch_size, seq_length, feature_size))
for i in range(seq_length):
Expand Down
9 changes: 5 additions & 4 deletions keras_nlp/layers/modeling/sine_position_encoding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.backend import random
from keras_nlp.layers.modeling.sine_position_encoding import (
SinePositionEncoding,
)
Expand All @@ -27,7 +28,7 @@ def test_layer_behaviors(self):
init_kwargs={
"max_wavelength": 10000,
},
input_data=ops.random.uniform(shape=(2, 4, 6)),
input_data=random.uniform(shape=(2, 4, 6)),
expected_output_shape=(2, 4, 6),
)

Expand All @@ -37,7 +38,7 @@ def test_layer_behaviors_4d(self):
init_kwargs={
"max_wavelength": 10000,
},
input_data=ops.random.uniform(shape=(1, 2, 4, 6)),
input_data=random.uniform(shape=(1, 2, 4, 6)),
expected_output_shape=(1, 2, 4, 6),
)

Expand Down Expand Up @@ -85,7 +86,7 @@ def test_output_correct_values(self):
pos_encoding,
]
)
input = ops.random.uniform(shape=[1, 4, 6])
input = random.uniform(shape=[1, 4, 6])
output = model(input)

# compare position encoding values for position 0 and 3
Expand All @@ -97,7 +98,7 @@ def test_output_correct_values(self):
def test_start_index(self):
batch_size, seq_length, feature_size = 2, 3, 4
layer = SinePositionEncoding()
data = ops.random.uniform(shape=(batch_size, seq_length, feature_size))
data = random.uniform(shape=(batch_size, seq_length, feature_size))
full_output = layer(data)
sequential_output = ops.zeros((batch_size, seq_length, feature_size))
for i in range(seq_length):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.backend import random
from keras_nlp.layers.modeling.token_and_position_embedding import (
TokenAndPositionEmbedding,
)
Expand All @@ -32,7 +33,7 @@ def test_layer_behaviors(self):
"embedding_dim": 3,
"embeddings_initializer": keras.initializers.Constant(1.0),
},
input_data=ops.random.randint(minval=0, maxval=5, shape=(2, 4)),
input_data=random.randint(minval=0, maxval=5, shape=(2, 4)),
expected_output_shape=(2, 4, 3),
expected_output_data=ops.ones((2, 4, 3)) * 2,
expected_num_trainable_weights=2,
Expand Down
Loading