From f19f3a3aac0fd92a026138072cedf6a8e71cb67f Mon Sep 17 00:00:00 2001 From: apupneja Date: Tue, 14 Feb 2023 23:57:00 +0530 Subject: [PATCH 1/3] adding support for tf.RaggedTensor --- keras_nlp/layers/sine_position_encoding.py | 32 ++++++++++++++++++---- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/keras_nlp/layers/sine_position_encoding.py b/keras_nlp/layers/sine_position_encoding.py index 7ae2ca11ba..5d29eb51cd 100644 --- a/keras_nlp/layers/sine_position_encoding.py +++ b/keras_nlp/layers/sine_position_encoding.py @@ -31,6 +31,11 @@ class SinePositionEncoding(keras.layers.Layer): positional encoding the same size as the embedded token tensor, which can be added directly to the embedded token tensor. + This layer optionally accepts `tf.RaggedTensor`s as inputs to process + batches of sequences of different lengths. The one ragged dimension must be + the dimension that corresponds to the sequence, that is, the penultimate + dimension. + Args: max_wavelength: The maximum angular wavelength of the sine/cosine curves, as described in Attention is All You Need. Defaults to @@ -65,10 +70,25 @@ def __init__( def call(self, inputs): # TODO(jbischof): replace `hidden_size` with`hidden_dim` for consistency # with other layers. - input_shape = tf.shape(inputs) - # length of sequence is the second last dimension of the inputs - seq_length = input_shape[-2] - hidden_size = input_shape[-1] + if isinstance(inputs, tf.RaggedTensor): + bounding_shape = inputs.bounding_shape() + position_embeddings = self._compute_trim_and_broadcast_position_embeddings( + bounding_shape, + ) + # then apply row lengths to recreate the same ragged shape as inputs + return tf.RaggedTensor.from_tensor( + position_embeddings, + inputs.nested_row_lengths(), + ) + else: + return self._compute_trim_and_broadcast_position_embeddings( + tf.shape(inputs), + ) + + + def _compute_trim_and_broadcast_position_embeddings(self, shape): + seq_length = shape[-2] + hidden_size = shape[-1] position = tf.cast(tf.range(seq_length), self.compute_dtype) min_freq = tf.cast(1 / self.max_wavelength, dtype=self.compute_dtype) timescales = tf.pow( @@ -82,10 +102,10 @@ def call(self, inputs): sin_mask = 1 - cos_mask # embedding shape is [seq_length, hidden_size] positional_encodings = ( - tf.sin(angles) * sin_mask + tf.cos(angles) * cos_mask + tf.sin(angles) * sin_mask + tf.cos(angles) * cos_mask ) - return tf.broadcast_to(positional_encodings, input_shape) + return tf.broadcast_to(positional_encodings, shape) def get_config(self): config = super().get_config() From a3f01fe281931e71b7278074daf00e77f168d4e2 Mon Sep 17 00:00:00 2001 From: apupneja Date: Tue, 14 Feb 2023 23:58:02 +0530 Subject: [PATCH 2/3] formatting the code --- keras_nlp/layers/sine_position_encoding.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/keras_nlp/layers/sine_position_encoding.py b/keras_nlp/layers/sine_position_encoding.py index 5d29eb51cd..3fabb5426c 100644 --- a/keras_nlp/layers/sine_position_encoding.py +++ b/keras_nlp/layers/sine_position_encoding.py @@ -72,8 +72,10 @@ def call(self, inputs): # with other layers. if isinstance(inputs, tf.RaggedTensor): bounding_shape = inputs.bounding_shape() - position_embeddings = self._compute_trim_and_broadcast_position_embeddings( - bounding_shape, + position_embeddings = ( + self._compute_trim_and_broadcast_position_embeddings( + bounding_shape, + ) ) # then apply row lengths to recreate the same ragged shape as inputs return tf.RaggedTensor.from_tensor( @@ -85,7 +87,6 @@ def call(self, inputs): tf.shape(inputs), ) - def _compute_trim_and_broadcast_position_embeddings(self, shape): seq_length = shape[-2] hidden_size = shape[-1] @@ -102,7 +103,7 @@ def _compute_trim_and_broadcast_position_embeddings(self, shape): sin_mask = 1 - cos_mask # embedding shape is [seq_length, hidden_size] positional_encodings = ( - tf.sin(angles) * sin_mask + tf.cos(angles) * cos_mask + tf.sin(angles) * sin_mask + tf.cos(angles) * cos_mask ) return tf.broadcast_to(positional_encodings, shape) From 21f2c5211012f83ba4c3ddc92ddd28cd543bb8ba Mon Sep 17 00:00:00 2001 From: apupneja Date: Thu, 16 Feb 2023 16:01:56 +0530 Subject: [PATCH 3/3] tests for ragged tensor --- .../layers/sine_position_encoding_test.py | 79 ++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/keras_nlp/layers/sine_position_encoding_test.py b/keras_nlp/layers/sine_position_encoding_test.py index c61f1c99ac..05aff73e2f 100644 --- a/keras_nlp/layers/sine_position_encoding_test.py +++ b/keras_nlp/layers/sine_position_encoding_test.py @@ -13,7 +13,6 @@ # limitations under the License. """Tests for Sinusoidal Positional encoding.""" - import tensorflow as tf from tensorflow import keras @@ -92,6 +91,84 @@ def test_output_correct_values(self): self.assertAllClose(output[0, 0, :], expected_encoding_position_0) self.assertAllClose(output[0, 3, :], expected_encoding_position_3) + def test_ragged_tensor_with_3_dimensions(self): + feature_size = 2 + test_layer = sine_position_encoding.SinePositionEncoding() + # Create a 3-dimensional ragged input (the first dimension is implicit). + input_tensor = keras.Input( + shape=(None, feature_size), dtype=tf.float32, ragged=True + ) + output_tensor = test_layer(input_tensor) + model = keras.Model(input_tensor, output_tensor) + + input_data = tf.ragged.constant( + [ + [[1.0, 1.0], [1.0, 1.0]], + [], + [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0]], + ], + ragged_rank=1, + inner_shape=(2,), + ) + expected_output_data = tf.ragged.constant( + [ + [[0.0, 1.0], [0.84147096, 0.5403023]], + [], + [[0.0, 1.0], [0.84147096, 0.5403023], [0.9092974, -0.41614684]], + [[0.0, 1.0]], + ], + ragged_rank=1, + inner_shape=(2,), + ) + output_data = model.predict(input_data) + self.assertAllClose(output_data, expected_output_data) + + def test_ragged_tensor_with_4_dimensions(self): + feature_size = 2 + test_layer = sine_position_encoding.SinePositionEncoding() + # Create a 4-dimensional ragged input (the first dimension is implicit). + input_tensor = keras.Input( + shape=(None, None, feature_size), dtype=tf.float32, ragged=True + ) + output_tensor = test_layer(input_tensor) + model = keras.Model(input_tensor, output_tensor) + + input_data = tf.ragged.constant( + [ + [ + [[1.0, 1.0], [1.0, 1.0]], + [], + ], + [ + [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], + [[1.0, 1.0]], + ], + ], + ragged_rank=2, + inner_shape=(2,), + ) + expected_output_data = tf.ragged.constant( + [ + [ + [[0.0, 1.0], [0.84147096, 0.5403023]], + [], + ], + [ + [ + [0.0, 1.0], + [0.84147096, 0.5403023], + [0.9092974, -0.41614684], + ], + [[0.0, 1.0]], + ], + ], + ragged_rank=2, + inner_shape=(2,), + ) + output_data = model.predict(input_data) + self.assertAllClose(output_data, expected_output_data) + def test_get_config_and_from_config(self): pos_encoding = sine_position_encoding.SinePositionEncoding( max_wavelength=1000,