31 changes: 26 additions & 5 deletions keras_nlp/layers/sine_position_encoding.py
@@ -31,6 +31,11 @@ class SinePositionEncoding(keras.layers.Layer):
     positional encoding the same size as the embedded token tensor, which
     can be added directly to the embedded token tensor.
 
+    This layer optionally accepts `tf.RaggedTensor`s as inputs to process
+    batches of sequences of different lengths. The one ragged dimension must be
+    the dimension that corresponds to the sequence, that is, the penultimate
+    dimension.
+
     Args:
         max_wavelength: The maximum angular wavelength of the sine/cosine
             curves, as described in Attention is All You Need. Defaults to
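A minimal usage sketch of the ragged-input support described in the docstring above, assuming a keras_nlp build that includes this change; the embedding values and shapes are illustrative only:

import tensorflow as tf
import keras_nlp

# A ragged batch of token embeddings: sequence lengths 3 and 2, hidden size 8.
embeddings = tf.ragged.constant(
    [[[0.1] * 8] * 3, [[0.2] * 8] * 2],
    ragged_rank=1,
)
pos_enc = keras_nlp.layers.SinePositionEncoding()(embeddings)
# The encoding comes back ragged with the same row lengths as the input,
# so it can be added directly to the embeddings.
outputs = embeddings + pos_enc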
@@ -65,10 +70,26 @@ def __init__(
     def call(self, inputs):
         # TODO(jbischof): replace `hidden_size` with `hidden_dim` for consistency
         # with other layers.
-        input_shape = tf.shape(inputs)
-        # length of sequence is the second last dimension of the inputs
-        seq_length = input_shape[-2]
-        hidden_size = input_shape[-1]
+        if isinstance(inputs, tf.RaggedTensor):
+            bounding_shape = inputs.bounding_shape()
+            position_embeddings = (
+                self._compute_trim_and_broadcast_position_embeddings(
+                    bounding_shape,
+                )
+            )
+            # then apply row lengths to recreate the same ragged shape as inputs
+            return tf.RaggedTensor.from_tensor(
+                position_embeddings,
+                inputs.nested_row_lengths(),
+            )
+        else:
+            return self._compute_trim_and_broadcast_position_embeddings(
+                tf.shape(inputs),
+            )
+
+    def _compute_trim_and_broadcast_position_embeddings(self, shape):
+        seq_length = shape[-2]
+        hidden_size = shape[-1]
         position = tf.cast(tf.range(seq_length), self.compute_dtype)
         min_freq = tf.cast(1 / self.max_wavelength, dtype=self.compute_dtype)
         timescales = tf.pow(
@@ -85,7 +106,7 @@ def call(self, inputs):
             tf.sin(angles) * sin_mask + tf.cos(angles) * cos_mask
         )
 
-        return tf.broadcast_to(positional_encodings, input_shape)
+        return tf.broadcast_to(positional_encodings, shape)
 
     def get_config(self):
         config = super().get_config()
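For reference, a small standalone TensorFlow sketch (not part of this PR) of the pattern used in `call` above: compute a dense result over the ragged input's bounding shape, then reapply the input's nested row lengths so the result has the same ragged shape:

import tensorflow as tf

ragged = tf.ragged.constant(
    [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0]]],
    ragged_rank=1,
)  # shape (2, None, 2) with row lengths (2, 1)
dense = tf.zeros(ragged.bounding_shape())  # dense shape (2, 2, 2)
rewrapped = tf.RaggedTensor.from_tensor(dense, ragged.nested_row_lengths())
print(rewrapped.shape)  # (2, None, 2): same raggedness as the input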