From 35573738b000f13aa5f49d65e0766e42fb2db8ea Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi Date: Tue, 17 Oct 2023 22:44:18 +0000 Subject: [PATCH 1/4] Change TF ops to Keras Core ops --- keras_nlp/models/t5/t5_backbone.py | 6 +- keras_nlp/models/t5/t5_backbone_test.py | 2 - keras_nlp/models/t5/t5_layer_norm.py | 10 +-- .../models/t5/t5_multi_head_attention.py | 83 ++++++++++--------- keras_nlp/models/t5/t5_transformer_layer.py | 28 +++++-- 5 files changed, 72 insertions(+), 57 deletions(-) diff --git a/keras_nlp/models/t5/t5_backbone.py b/keras_nlp/models/t5/t5_backbone.py index 7514cc51ae..0ea07c6a9b 100644 --- a/keras_nlp/models/t5/t5_backbone.py +++ b/keras_nlp/models/t5/t5_backbone.py @@ -81,8 +81,6 @@ def __init__( tie_embedding_weights=False, **kwargs, ): - assert_tf_backend(self.__class__.__name__) - # Encoder inputs encoder_token_ids = keras.Input( shape=(None,), dtype="int32", name="encoder_token_ids" @@ -162,7 +160,7 @@ def __init__( position_bias = None for i in range(num_layers): - x, position_bias = T5TransformerLayer( + output = T5TransformerLayer( is_decoder=True, hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, @@ -181,6 +179,8 @@ def __init__( encoder_attention_mask=encoder_attention_mask, use_causal_mask=True, ) + if isinstance(output, tuple): + x, position_bias = output x = T5LayerNorm( epsilon=layer_norm_epsilon, diff --git a/keras_nlp/models/t5/t5_backbone_test.py b/keras_nlp/models/t5/t5_backbone_test.py index 476304c566..96b9700c11 100644 --- a/keras_nlp/models/t5/t5_backbone_test.py +++ b/keras_nlp/models/t5/t5_backbone_test.py @@ -22,8 +22,6 @@ from keras_nlp.models.t5.t5_backbone import T5Backbone from keras_nlp.tests.test_case import TestCase - -@pytest.mark.tf_only class T5Test(TestCase): def setUp(self): self.backbone = T5Backbone( diff --git a/keras_nlp/models/t5/t5_layer_norm.py b/keras_nlp/models/t5/t5_layer_norm.py index 7cfdb2315e..5695f4b452 100644 --- a/keras_nlp/models/t5/t5_layer_norm.py +++ b/keras_nlp/models/t5/t5_layer_norm.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf - from keras_nlp.backend import keras - +from keras_nlp.backend import ops class T5LayerNorm(keras.layers.Layer): def __init__(self, epsilon=1e-6, **kwargs): @@ -31,8 +29,8 @@ def build(self, input_shape): self.built = True def call(self, hidden_states): - variance = tf.math.reduce_mean( - tf.math.square(hidden_states), axis=-1, keepdims=True + variance = ops.mean( + ops.square(hidden_states), axis=-1, keepdims=True ) - hidden_states = hidden_states * tf.math.rsqrt(variance + self.epsilon) + hidden_states = hidden_states * ops.rsqrt(variance + self.epsilon) return self.weight * hidden_states diff --git a/keras_nlp/models/t5/t5_multi_head_attention.py b/keras_nlp/models/t5/t5_multi_head_attention.py index 479de51e7d..38c18e13c0 100644 --- a/keras_nlp/models/t5/t5_multi_head_attention.py +++ b/keras_nlp/models/t5/t5_multi_head_attention.py @@ -12,18 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import tensorflow as tf -from tensorflow.compiler.tf2xla.python.xla import dynamic_slice +import numpy as np from keras_nlp.backend import keras +from keras_nlp.backend import ops def shape_list(tensor): - dynamic = tf.shape(tensor) - if tensor.shape == tf.TensorShape(None): - return dynamic - static = tensor.shape.as_list() - return [dynamic[i] if s is None else s for i, s in enumerate(static)] + return ops.shape(tensor) + # dynamic = tf.shape(tensor) + # if tensor.shape == tf.TensorShape(None): + # return dynamic + # static = tensor.shape.as_list() + # return [dynamic[i] if s is None else s for i, s in enumerate(static)] class T5MultiHeadAttention(keras.layers.Layer): @@ -123,39 +124,39 @@ def _relative_position_bucket( if bidirectional: num_buckets //= 2 relative_buckets += ( - tf.cast( - tf.math.greater(relative_position, 0), + ops.cast( + ops.greater(relative_position, 0), dtype=relative_position.dtype, ) * num_buckets ) - relative_position = tf.math.abs(relative_position) + relative_position = ops.abs(relative_position) else: - relative_position = -tf.math.minimum(relative_position, 0) + relative_position = -ops.minimum(relative_position, 0) # now n is in the range [0, inf) max_exact = num_buckets // 2 - is_small = tf.math.less(relative_position, max_exact) - relative_position_if_large = max_exact + tf.cast( - tf.math.log( - tf.cast(relative_position, "float32") - / tf.cast(max_exact, "float32") + is_small = ops.less(relative_position, max_exact) + relative_position_if_large = max_exact + ops.cast( + ops.log( + ops.cast(relative_position, "float32") + / ops.cast(max_exact, "float32") ) - / tf.math.log(max_distance / max_exact) + / ops.cast(ops.log(max_distance / max_exact), "float32") * (num_buckets - max_exact), dtype=relative_position.dtype, ) - relative_position_if_large = tf.math.minimum( + relative_position_if_large = ops.minimum( relative_position_if_large, num_buckets - 1 ) - relative_buckets += tf.where( + relative_buckets += ops.where( is_small, relative_position, relative_position_if_large ) return relative_buckets def compute_bias(self, query_length, key_length): """Compute binned relative position bias""" - context_position = tf.range(query_length)[:, None] - memory_position = tf.range(key_length)[None, :] + context_position = ops.arange(query_length)[:, None] + memory_position = ops.arange(key_length)[None, :] relative_position = ( memory_position - context_position ) # shape (query_length, key_length) @@ -165,11 +166,11 @@ def compute_bias(self, query_length, key_length): num_buckets=self.relative_attention_buckets, max_distance=self.relative_attention_max_distance, ) - values = tf.gather( - self.relative_attention_bias, relative_position_bucket + values = ops.take( + self.relative_attention_bias, relative_position_bucket, axis=0 ) # shape (query_length, key_length, num_heads) - values = tf.expand_dims( - tf.transpose(values, [2, 0, 1]), axis=0 + values = ops.expand_dims( + ops.transpose(values, axes=(2, 0, 1)), axis=0 ) # shape (1, num_heads, query_length, key_length) return values @@ -209,17 +210,17 @@ def call( ) def shape(hidden_states): - return tf.transpose( - tf.reshape( + return ops.transpose( + ops.reshape( hidden_states, (batch_size, -1, self.num_heads, self.key_value_dim), ), - perm=(0, 2, 1, 3), + axes=(0, 2, 1, 3), ) def unshape(hidden_states): - return tf.reshape( - tf.transpose(hidden_states, perm=(0, 2, 1, 3)), + return ops.reshape( + ops.transpose(hidden_states, axes=(0, 2, 1, 3)), (batch_size, -1, self.inner_dim), ) @@ -240,7 +241,7 @@ def project( if 
key_value_states is None: # self-attention # (batch_size, num_heads, key_length, dim_per_head) - hidden_states = tf.concat( + hidden_states = ops.concat( [past_key_value, hidden_states], axis=2 ) else: @@ -267,13 +268,13 @@ def project( past_key_value[1] if past_key_value is not None else None, ) - scores = tf.einsum( + scores = ops.einsum( "bnqd,bnkd->bnqk", query_states, key_states ) # (batch_size, num_heads, query_length, key_length) if position_bias is None: if not self.use_relative_attention_bias: - position_bias = tf.zeros( + position_bias = ops.zeros( (1, self.num_heads, real_seq_length, key_length), self.compute_dtype, ) @@ -289,10 +290,10 @@ def project( # we might have a padded past structure, # in which case we want to fetch the position bias slice # right after the most recently filled past index - most_recently_filled_past_index = tf.reduce_max( - tf.where(past_key_value[0][0, 0, :, 0] != 0.0) + most_recently_filled_past_index = ops.amax( + ops.where(past_key_value[0][0, 0, :, 0] != 0.0) ) - position_bias = dynamic_slice( + position_bias = ops.slice( position_bias, (0, 0, most_recently_filled_past_index + 1, 0), (1, self.num_heads, seq_length, real_seq_length), @@ -300,13 +301,13 @@ def project( if mask is not None: # Add a new mask axis for the head dim. - mask = mask[:, tf.newaxis, :, :] + mask = mask[:, np.newaxis, :, :] # Add a very large negative position bias for masked positions. - mask = (1.0 - tf.cast(mask, position_bias.dtype)) * -1e9 + mask = (1.0 - ops.cast(mask, position_bias.dtype)) * -1e9 position_bias = position_bias + mask scores += position_bias - weights = tf.nn.softmax( + weights = ops.nn.softmax( scores, axis=-1 ) # (batch_size, num_heads, query_length, key_length) weights = self.dropout_layer( @@ -315,9 +316,9 @@ def project( # Optionally mask heads if layer_head_mask is not None: - weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights + weights = ops.reshape(layer_head_mask, (1, -1, 1, 1)) * weights - attention_output = tf.matmul( + attention_output = ops.matmul( weights, value_states ) # (batch_size, num_heads, query_length, dim_per_head) diff --git a/keras_nlp/models/t5/t5_transformer_layer.py b/keras_nlp/models/t5/t5_transformer_layer.py index ce4a28d67f..28b483f7a2 100644 --- a/keras_nlp/models/t5/t5_transformer_layer.py +++ b/keras_nlp/models/t5/t5_transformer_layer.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import tensorflow as tf - from keras_nlp.backend import keras from keras_nlp.layers.modeling.transformer_layer_utils import ( compute_causal_mask, @@ -21,6 +19,7 @@ from keras_nlp.models.t5.t5_layer_norm import T5LayerNorm from keras_nlp.models.t5.t5_multi_head_attention import T5MultiHeadAttention +from keras_nlp.backend import ops class T5TransformerLayer(keras.layers.Layer): def __init__( @@ -92,6 +91,22 @@ def __init__( self.layer_norm = T5LayerNorm(epsilon=layer_norm_epsilon) self.dropout_layer = keras.layers.Dropout(dropout) + def build(self, input_shape): + self.self_attention_layer_norm.build(input_shape) + self.self_attention.build(input_shape) + self.self_attention_dropout.build(input_shape) + if self.is_decoder: + self.cross_attention.build(input_shape) + self.cross_attention_layer_norm.build(input_shape) + self.cross_attention_dropout.build(input_shape) + if self.use_gated_activation: + self.gate_projector.build(input_shape) + self.output_projector.build(input_shape) + self.layer_norm.build(input_shape) + self.dropout_layer.build(input_shape) + self.built = True + + def call( self, hidden_states, @@ -103,10 +118,10 @@ def call( training=False, ): if use_causal_mask: - shape = tf.shape(hidden_states) + shape = ops.shape(hidden_states) batch_size, length = shape[0], shape[1] causal_mask = compute_causal_mask(batch_size, length, length) - attention_mask = tf.cast(attention_mask, "int32") + attention_mask = ops.cast(attention_mask, "int32") attention_mask = causal_mask & attention_mask x = hidden_states # Intermediate result. @@ -147,4 +162,7 @@ def call( x = self.dropout_layer(x, training=training) x = x + residual - return x, position_bias + if position_bias is not None: + return x, position_bias + else: + return x From 13ea25a857f87054db8bf374f6429343d8517d11 Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi Date: Tue, 17 Oct 2023 22:45:01 +0000 Subject: [PATCH 2/4] Fix formatting --- keras_nlp/layers/preprocessing/masked_lm_mask_generator.py | 6 +----- keras_nlp/models/t5/t5_backbone.py | 1 - keras_nlp/models/t5/t5_backbone_test.py | 1 + keras_nlp/models/t5/t5_layer_norm.py | 5 ++--- keras_nlp/models/t5/t5_transformer_layer.py | 3 +-- 5 files changed, 5 insertions(+), 11 deletions(-) diff --git a/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py b/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py index 74b2fd9811..2a7553e956 100644 --- a/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py +++ b/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py @@ -170,11 +170,7 @@ def __init__( def call(self, inputs): inputs, unbatched, rectangular = convert_to_ragged_batch(inputs) - ( - token_ids, - mask_positions, - mask_ids, - ) = tf_text.mask_language_model( + (token_ids, mask_positions, mask_ids,) = tf_text.mask_language_model( inputs, item_selector=self._random_selector, mask_values_chooser=self._mask_values_chooser, diff --git a/keras_nlp/models/t5/t5_backbone.py b/keras_nlp/models/t5/t5_backbone.py index 0ea07c6a9b..0bfceba418 100644 --- a/keras_nlp/models/t5/t5_backbone.py +++ b/keras_nlp/models/t5/t5_backbone.py @@ -19,7 +19,6 @@ from keras_nlp.models.t5.t5_layer_norm import T5LayerNorm from keras_nlp.models.t5.t5_transformer_layer import T5TransformerLayer from keras_nlp.utils.python_utils import classproperty -from keras_nlp.utils.tensor_utils import assert_tf_backend @keras_nlp_export("keras_nlp.models.T5Backbone") diff --git a/keras_nlp/models/t5/t5_backbone_test.py b/keras_nlp/models/t5/t5_backbone_test.py index 96b9700c11..7611ef44fd 100644 --- 
a/keras_nlp/models/t5/t5_backbone_test.py +++ b/keras_nlp/models/t5/t5_backbone_test.py @@ -22,6 +22,7 @@ from keras_nlp.models.t5.t5_backbone import T5Backbone from keras_nlp.tests.test_case import TestCase + class T5Test(TestCase): def setUp(self): self.backbone = T5Backbone( diff --git a/keras_nlp/models/t5/t5_layer_norm.py b/keras_nlp/models/t5/t5_layer_norm.py index 5695f4b452..b4f157c004 100644 --- a/keras_nlp/models/t5/t5_layer_norm.py +++ b/keras_nlp/models/t5/t5_layer_norm.py @@ -15,6 +15,7 @@ from keras_nlp.backend import keras from keras_nlp.backend import ops + class T5LayerNorm(keras.layers.Layer): def __init__(self, epsilon=1e-6, **kwargs): super().__init__(**kwargs) @@ -29,8 +30,6 @@ def build(self, input_shape): self.built = True def call(self, hidden_states): - variance = ops.mean( - ops.square(hidden_states), axis=-1, keepdims=True - ) + variance = ops.mean(ops.square(hidden_states), axis=-1, keepdims=True) hidden_states = hidden_states * ops.rsqrt(variance + self.epsilon) return self.weight * hidden_states diff --git a/keras_nlp/models/t5/t5_transformer_layer.py b/keras_nlp/models/t5/t5_transformer_layer.py index 28b483f7a2..253a6abe84 100644 --- a/keras_nlp/models/t5/t5_transformer_layer.py +++ b/keras_nlp/models/t5/t5_transformer_layer.py @@ -13,13 +13,13 @@ # limitations under the License. from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.layers.modeling.transformer_layer_utils import ( compute_causal_mask, ) from keras_nlp.models.t5.t5_layer_norm import T5LayerNorm from keras_nlp.models.t5.t5_multi_head_attention import T5MultiHeadAttention -from keras_nlp.backend import ops class T5TransformerLayer(keras.layers.Layer): def __init__( @@ -106,7 +106,6 @@ def build(self, input_shape): self.dropout_layer.build(input_shape) self.built = True - def call( self, hidden_states, From 90dcf7d9eaf16b3bf8e7353c4000bac7565edd18 Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi Date: Tue, 17 Oct 2023 23:53:53 +0000 Subject: [PATCH 3/4] Remove build override --- keras_nlp/models/t5/t5_backbone.py | 4 +++- keras_nlp/models/t5/t5_transformer_layer.py | 15 --------------- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/keras_nlp/models/t5/t5_backbone.py b/keras_nlp/models/t5/t5_backbone.py index 0bfceba418..13db116f43 100644 --- a/keras_nlp/models/t5/t5_backbone.py +++ b/keras_nlp/models/t5/t5_backbone.py @@ -118,7 +118,7 @@ def __init__( position_bias = None for i in range(num_layers): - x, position_bias = T5TransformerLayer( + output = T5TransformerLayer( is_decoder=False, hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, @@ -135,6 +135,8 @@ def __init__( position_bias=position_bias, use_causal_mask=False, ) + if isinstance(output, tuple): + x, position_bias = output x = T5LayerNorm( epsilon=layer_norm_epsilon, diff --git a/keras_nlp/models/t5/t5_transformer_layer.py b/keras_nlp/models/t5/t5_transformer_layer.py index 253a6abe84..22c2dc1c74 100644 --- a/keras_nlp/models/t5/t5_transformer_layer.py +++ b/keras_nlp/models/t5/t5_transformer_layer.py @@ -91,21 +91,6 @@ def __init__( self.layer_norm = T5LayerNorm(epsilon=layer_norm_epsilon) self.dropout_layer = keras.layers.Dropout(dropout) - def build(self, input_shape): - self.self_attention_layer_norm.build(input_shape) - self.self_attention.build(input_shape) - self.self_attention_dropout.build(input_shape) - if self.is_decoder: - self.cross_attention.build(input_shape) - self.cross_attention_layer_norm.build(input_shape) - self.cross_attention_dropout.build(input_shape) - if 
self.use_gated_activation: - self.gate_projector.build(input_shape) - self.output_projector.build(input_shape) - self.layer_norm.build(input_shape) - self.dropout_layer.build(input_shape) - self.built = True - def call( self, hidden_states, From c629dd43071464dc5aac74f13efa274ad7ffcc70 Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi Date: Wed, 18 Oct 2023 00:08:50 +0000 Subject: [PATCH 4/4] Fix formatting and remove unneeded function --- .../preprocessing/masked_lm_mask_generator.py | 6 +++++- keras_nlp/models/t5/t5_multi_head_attention.py | 15 +++------------ 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py b/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py index 2a7553e956..74b2fd9811 100644 --- a/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py +++ b/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py @@ -170,7 +170,11 @@ def __init__( def call(self, inputs): inputs, unbatched, rectangular = convert_to_ragged_batch(inputs) - (token_ids, mask_positions, mask_ids,) = tf_text.mask_language_model( + ( + token_ids, + mask_positions, + mask_ids, + ) = tf_text.mask_language_model( inputs, item_selector=self._random_selector, mask_values_chooser=self._mask_values_chooser, diff --git a/keras_nlp/models/t5/t5_multi_head_attention.py b/keras_nlp/models/t5/t5_multi_head_attention.py index 38c18e13c0..5cb59769dc 100644 --- a/keras_nlp/models/t5/t5_multi_head_attention.py +++ b/keras_nlp/models/t5/t5_multi_head_attention.py @@ -18,15 +18,6 @@ from keras_nlp.backend import ops -def shape_list(tensor): - return ops.shape(tensor) - # dynamic = tf.shape(tensor) - # if tensor.shape == tf.TensorShape(None): - # return dynamic - # static = tensor.shape.as_list() - # return [dynamic[i] if s is None else s for i, s in enumerate(static)] - - class T5MultiHeadAttention(keras.layers.Layer): # This layer is adapted from Hugging Face # Ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_tf_t5.py @@ -187,7 +178,7 @@ def call( ): # Input is (batch_size, query_length, dim) # past_key_value[0] is (batch_size, num_heads, q_len - 1, dim_per_head) - batch_size, seq_length = shape_list(hidden_states)[:2] + batch_size, seq_length = ops.shape(hidden_states)[:2] real_seq_length = seq_length @@ -198,7 +189,7 @@ def call( f"keys and values. Got {len(past_key_value)} past states." ) real_seq_length += ( - shape_list(past_key_value[0])[2] + ops.shape(past_key_value[0])[2] if query_length is None else query_length ) @@ -206,7 +197,7 @@ def call( key_length = ( real_seq_length if key_value_states is None - else shape_list(key_value_states)[1] + else ops.shape(key_value_states)[1] ) def shape(hidden_states):
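
Note on the conversion pattern (illustrative, not part of the patch series): the diffs above mechanically swap TensorFlow calls for their backend-agnostic equivalents in keras_nlp.backend.ops, e.g. tf.math.reduce_mean -> ops.mean, tf.math.rsqrt -> ops.rsqrt, tf.transpose(..., perm=...) -> ops.transpose(..., axes=...). Below is a minimal sketch of the converted T5 layer-norm math, assuming the ops module re-exports the same Keras Core ops used in the hunks above; the NumPy inputs and helper function are only for illustration.

    # Sketch only -- mirrors the math of T5LayerNorm.call after the conversion.
    import numpy as np

    from keras_nlp.backend import ops

    def t5_layer_norm(hidden_states, weight, epsilon=1e-6):
        # RMS-style normalization: scale by the reciprocal root of the mean
        # square along the last axis, with no mean subtraction.
        variance = ops.mean(ops.square(hidden_states), axis=-1, keepdims=True)
        return weight * (hidden_states * ops.rsqrt(variance + epsilon))

    x = np.random.rand(2, 4, 8).astype("float32")  # (batch, seq_len, hidden_dim)
    w = np.ones((8,), dtype="float32")
    print(ops.shape(t5_layer_norm(x, w)))  # (2, 4, 8), regardless of backend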