31 changes: 26 additions & 5 deletions keras_nlp/layers/sine_position_encoding.py
@@ -31,6 +31,11 @@ class SinePositionEncoding(keras.layers.Layer):
positional encoding the same size as the embedded token tensor, which
can be added directly to the embedded token tensor.

This layer optionally accepts `tf.RaggedTensor` inputs to process batches
of sequences of different lengths. The single ragged dimension must be the
sequence dimension, that is, the penultimate dimension.

Args:
max_wavelength: The maximum angular wavelength of the sine/cosine
curves, as described in Attention is All You Need. Defaults to
@@ -65,10 +70,26 @@ def __init__(
def call(self, inputs):
# TODO(jbischof): replace `hidden_size` with `hidden_dim` for consistency
# with other layers.
input_shape = tf.shape(inputs)
# length of sequence is the second last dimension of the inputs
seq_length = input_shape[-2]
hidden_size = input_shape[-1]
if isinstance(inputs, tf.RaggedTensor):
bounding_shape = inputs.bounding_shape()
position_embeddings = (
self._compute_trim_and_broadcast_position_embeddings(
bounding_shape,
)
)
# then apply row lengths to recreate the same ragged shape as inputs
return tf.RaggedTensor.from_tensor(
position_embeddings,
inputs.nested_row_lengths(),
)
else:
return self._compute_trim_and_broadcast_position_embeddings(
tf.shape(inputs),
)

def _compute_trim_and_broadcast_position_embeddings(self, shape):
seq_length = shape[-2]
hidden_size = shape[-1]
position = tf.cast(tf.range(seq_length), self.compute_dtype)
min_freq = tf.cast(1 / self.max_wavelength, dtype=self.compute_dtype)
timescales = tf.pow(
@@ -85,7 +106,7 @@ def call(self, inputs):
tf.sin(angles) * sin_mask + tf.cos(angles) * cos_mask
)

return tf.broadcast_to(positional_encodings, input_shape)
return tf.broadcast_to(positional_encodings, shape)

def get_config(self):
config = super().get_config()
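A minimal usage sketch of the ragged-input support described in the new docstring, assuming the layer is exported from `keras_nlp.layers` as `SinePositionEncoding` (the usual package export for this file):

import tensorflow as tf
from keras_nlp.layers import SinePositionEncoding

layer = SinePositionEncoding()

# A batch of two sequences of different lengths; the ragged (sequence)
# dimension is the penultimate one, as the docstring requires.
inputs = tf.ragged.constant(
    [
        [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
        [[1.0, 1.0]],
    ],
    ragged_rank=1,
    inner_shape=(2,),
)

# The layer returns a tf.RaggedTensor with the same row lengths as the inputs.
outputs = layer(inputs)
print(outputs.row_lengths())  # [3, 1]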
79 changes: 78 additions & 1 deletion keras_nlp/layers/sine_position_encoding_test.py
@@ -13,7 +13,6 @@
# limitations under the License.
"""Tests for Sinusoidal Positional encoding."""


import tensorflow as tf
from tensorflow import keras

@@ -92,6 +91,84 @@ def test_output_correct_values(self):
self.assertAllClose(output[0, 0, :], expected_encoding_position_0)
self.assertAllClose(output[0, 3, :], expected_encoding_position_3)

def test_ragged_tensor_with_3_dimensions(self):
feature_size = 2
test_layer = sine_position_encoding.SinePositionEncoding()
# Create a 3-dimensional ragged input (the first dimension is implicit).
input_tensor = keras.Input(
shape=(None, feature_size), dtype=tf.float32, ragged=True
)
output_tensor = test_layer(input_tensor)
model = keras.Model(input_tensor, output_tensor)

input_data = tf.ragged.constant(
[
[[1.0, 1.0], [1.0, 1.0]],
[],
[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
[[1.0, 1.0]],
],
ragged_rank=1,
inner_shape=(2,),
)
expected_output_data = tf.ragged.constant(
[
[[0.0, 1.0], [0.84147096, 0.5403023]],
[],
[[0.0, 1.0], [0.84147096, 0.5403023], [0.9092974, -0.41614684]],
[[0.0, 1.0]],
],
ragged_rank=1,
inner_shape=(2,),
)
output_data = model.predict(input_data)
self.assertAllClose(output_data, expected_output_data)

def test_ragged_tensor_with_4_dimensions(self):
feature_size = 2
test_layer = sine_position_encoding.SinePositionEncoding()
# Create a 4-dimensional ragged input (the first dimension is implicit).
input_tensor = keras.Input(
shape=(None, None, feature_size), dtype=tf.float32, ragged=True
)
output_tensor = test_layer(input_tensor)
model = keras.Model(input_tensor, output_tensor)

input_data = tf.ragged.constant(
[
[
[[1.0, 1.0], [1.0, 1.0]],
[],
],
[
[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
[[1.0, 1.0]],
],
],
ragged_rank=2,
inner_shape=(2,),
)
expected_output_data = tf.ragged.constant(
[
[
[[0.0, 1.0], [0.84147096, 0.5403023]],
[],
],
[
[
[0.0, 1.0],
[0.84147096, 0.5403023],
[0.9092974, -0.41614684],
],
[[0.0, 1.0]],
],
],
ragged_rank=2,
inner_shape=(2,),
)
output_data = model.predict(input_data)
self.assertAllClose(output_data, expected_output_data)

def test_get_config_and_from_config(self):
pos_encoding = sine_position_encoding.SinePositionEncoding(
max_wavelength=1000,
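For reference, the expected values in the ragged tests follow from the sinusoidal formula: assuming the layer's standard timescale exponent 2 * (i // 2) / hidden_size, a feature size of 2 makes both timescales equal 1, so position p encodes to [sin(p), cos(p)]. A quick sketch using only the standard library:

import math

# With hidden_size == 2 the exponent 2 * (i // 2) / hidden_size is 0 for both
# features, so every timescale is 1 and position p maps to [sin(p), cos(p)].
for p in range(3):
    print(p, [math.sin(p), math.cos(p)])
# 0 [0.0, 1.0]
# 1 [0.8414709848078965, 0.5403023058681398]
# 2 [0.9092974268256817, -0.4161468365471424]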