7 changes: 6 additions & 1 deletion keras_nlp/src/layers/modeling/cached_multi_head_attention.py
@@ -65,6 +65,8 @@ class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
`cache` (usually the index of the current token being processed
when running generation). If `cache_update_index=None` while `cache`
is set, the cache will not be updated.
training: a boolean indicating whether the layer should behave in
training mode or in inference mode.

Returns:
An `(attention_output, cache)` tuple. `attention_output` is the result
@@ -83,6 +85,7 @@ def call(
attention_mask=None,
cache=None,
cache_update_index=None,
training=None,
):
if (
hasattr(self, "_build_from_signature")
@@ -133,7 +136,9 @@ def call(
attention_scores = self._masked_softmax(
attention_scores, attention_mask
)
attention_scores = self._dropout_layer(attention_scores)
attention_scores = self._dropout_layer(
attention_scores, training=training
)

attention_output = ops.einsum(
self._combine_equation, attention_scores, value
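
A minimal usage sketch of the new behavior (not part of the changed files; the import path is assumed to follow the file layout above, and the call pattern mirrors the new test below): the attention-probability dropout inside `CachedMultiHeadAttention` is now only applied when the layer is called with `training=True`.

from keras_nlp.src.backend import random
from keras_nlp.src.layers.modeling.cached_multi_head_attention import (
    CachedMultiHeadAttention,
)

# Shapes mirror the new test below: hidden dim = num_heads * key_dim = 8.
layer = CachedMultiHeadAttention(num_heads=2, key_dim=4, dropout=0.5)
x = random.uniform(shape=(2, 5, 8))

train_out = layer(x, x, training=True)   # dropout applied to attention scores
infer_out = layer(x, x, training=False)  # deterministic inference path
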
27 changes: 27 additions & 0 deletions keras_nlp/src/layers/modeling/cached_multi_head_attention_test.py
@@ -91,3 +91,30 @@ def call(outputs, cache):

self.assertAllClose(output, no_loop_outputs)
self.assertAllClose(output_cache, no_loop_cache)

def test_training_propagation(self):
batch_size = 2
seq_len = 5
num_heads = 2
key_dim = 4
hidden_dim = num_heads * key_dim

input_shape = (batch_size, seq_len, hidden_dim)
x = random.uniform(shape=input_shape)

layer = CachedMultiHeadAttention(
num_heads=num_heads,
key_dim=key_dim,
dropout=0.99999, # Zeros out the outputs after the dropout layer
)
outputs = layer(x, x, training=True)

# Custom computation with dropout rate set to about 1.0
value = layer._value_dense(x)
attention_scores = ops.zeros((batch_size, num_heads, seq_len, seq_len))
attention_output = ops.einsum(
layer._combine_equation, attention_scores, value
)
attention_output = layer._output_dense(attention_output)

self.assertAllClose(outputs, attention_output, atol=1e-5)
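
A short aside on the test strategy used throughout this PR, sketched with plain Keras (an assumption for illustration; the diff itself only relies on each layer's internal dropout): a `Dropout` layer with a rate just below 1.0 zeroes out essentially every element when called with `training=True` and is the identity when `training=False`, which is why the expected outputs above can be recomputed with the post-dropout tensors replaced by zeros.

import keras
import numpy as np

x = np.ones((2, 3), dtype="float32")
drop = keras.layers.Dropout(rate=0.99999)

print(drop(x, training=True))   # almost surely all zeros: dropout active
print(drop(x, training=False))  # unchanged input: dropout bypassed
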
6 changes: 4 additions & 2 deletions keras_nlp/src/layers/modeling/f_net_encoder.py
@@ -134,12 +134,14 @@ def build(self, inputs_shape):
)
self.built = True

def call(self, inputs):
def call(self, inputs, training=None):
"""Forward pass of the FNetEncoder.

Args:
inputs: a Tensor. The input data to FNetEncoder, should be
of shape [batch_size, sequence_length, feature_dim].
training: a boolean indicating whether the layer should behave in
training mode or in inference mode.

Returns:
A Tensor of the same shape as the `inputs`.
@@ -160,7 +162,7 @@ def add_and_norm(input1, input2, norm_layer):
def feed_forward(input):
x = self._intermediate_dense(input)
x = self._output_dense(x)
return self._output_dropout(x)
return self._output_dropout(x, training=training)

mixing_output = fourier_transform(inputs)

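
A similar usage sketch for `FNetEncoder`, reusing the import that already appears in the test file below (the rest is an assumption for illustration): the `training` flag is now forwarded to the encoder's dropout layers.

from keras_nlp.src.backend import random
from keras_nlp.src.layers.modeling.f_net_encoder import FNetEncoder

encoder = FNetEncoder(intermediate_dim=4, dropout=0.5)
x = random.uniform(shape=(2, 4, 6))

train_out = encoder(x, training=True)   # output dropout active
infer_out = encoder(x, training=False)  # dropout disabled at inference
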
32 changes: 32 additions & 0 deletions keras_nlp/src/layers/modeling/f_net_encoder_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.src.backend import ops
from keras_nlp.src.backend import random
from keras_nlp.src.layers.modeling.f_net_encoder import FNetEncoder
from keras_nlp.src.tests.test_case import TestCase
@@ -42,3 +43,34 @@ def test_value_error_when_invalid_kernel_initializer(self):
dropout=0.5,
kernel_initializer="Invalid",
)

def test_training_propagation(self):
x = random.uniform(shape=(2, 4, 6))
layer = FNetEncoder(
intermediate_dim=4,
dropout=0.99999, # Zeros out the outputs after the dropout layer
)
outputs = layer(x, training=True)

# Custom computation with dropout rate set to about 1.0
def fourier_transform(input):
# Apply FFT on the input and take the real part.
input_dtype = input.dtype
# FFT transforms do not support float16.
input = ops.cast(input, "float32")
real_in, imaginary_in = (input, ops.zeros_like(input))
real_out, _ = ops.fft2((real_in, imaginary_in))
return ops.cast(real_out, input_dtype)

def add_and_norm(input1, input2, norm_layer):
return norm_layer(input1 + input2)

mixing_output = fourier_transform(x)
mixing_output = add_and_norm(x, mixing_output, layer._mixing_layer_norm)
x = add_and_norm(
mixing_output,
ops.zeros_like(mixing_output),
layer._output_layer_norm,
)

self.assertAllClose(outputs, x, atol=1e-5)
12 changes: 9 additions & 3 deletions keras_nlp/src/layers/modeling/transformer_decoder.py
@@ -279,6 +279,7 @@ def call(
cross_attention_cache=None,
cross_attention_cache_update_index=None,
use_causal_mask=True,
training=None,
):
"""Forward pass of the TransformerDecoder.

@@ -315,6 +316,9 @@
`None` (reuse a previously computed `cross_attention_cache`).
use_causal_mask: bool, defaults to `True`. If true, a causal mask
(masking out future input) is applied on the decoder sequence.
training: a boolean indicating whether the layer should behave in
training mode or in inference mode.

Returns:
One of three things, depending on call arguments:
- `outputs`, if `self_attention_cache` is `None`.
@@ -385,12 +389,13 @@ def call(
attention_mask=self_attention_mask,
cache=self_attention_cache,
cache_update_index=self_attention_cache_update_index,
training=training,
)
if self_attention_cache is None:
x = attention_output
else:
x, self_attention_cache = attention_output
x = self._self_attention_dropout(x)
x = self._self_attention_dropout(x, training=training)
x = x + residual
if not self.normalize_first:
x = self._self_attention_layer_norm(x)
@@ -412,12 +417,13 @@
attention_mask=cross_attention_mask,
cache=cross_attention_cache,
cache_update_index=cross_attention_cache_update_index,
training=training,
)
if cross_attention_cache is None:
x = attention_output
else:
x, cross_attention_cache = attention_output
x = self._cross_attention_dropout(x)
x = self._cross_attention_dropout(x, training=training)
x = x + residual
if not self.normalize_first:
x = self._cross_attention_layer_norm(x)
@@ -428,7 +434,7 @@
x = self._feedforward_layer_norm(x)
x = self._feedforward_intermediate_dense(x)
x = self._feedforward_output_dense(x)
x = self._feedforward_dropout(x)
x = self._feedforward_dropout(x, training=training)
x = x + residual
if not self.normalize_first:
x = self._feedforward_layer_norm(x)
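
A usage sketch for the decoder path (import path assumed from the file location, not shown in this diff): a single `training` argument now reaches the self-attention, cross-attention, and feedforward dropout layers.

from keras_nlp.src.backend import random
from keras_nlp.src.layers.modeling.transformer_decoder import (
    TransformerDecoder,
)

decoder = TransformerDecoder(intermediate_dim=4, num_heads=2, dropout=0.5)
decoder_sequence = random.uniform(shape=[1, 4, 6])
encoder_sequence = random.uniform(shape=[1, 4, 6])

# Passing an encoder sequence builds the cross-attention sub-layer as well.
train_out = decoder(decoder_sequence, encoder_sequence, training=True)
infer_out = decoder(decoder_sequence, encoder_sequence, training=False)
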
17 changes: 17 additions & 0 deletions keras_nlp/src/layers/modeling/transformer_decoder_test.py
@@ -103,6 +103,23 @@ def test_value_error_when_invalid_kernel_inititalizer(self):
kernel_initializer="Invalid",
)

def test_training_propagation(self):
decoder = TransformerDecoder(
intermediate_dim=4,
num_heads=2,
dropout=0.99999, # Zeros out the outputs after the dropout layer
)
decoder_sequence = random.uniform(shape=[1, 4, 6])
encoder_sequence = random.uniform(shape=[1, 4, 6])
outputs = decoder(decoder_sequence, encoder_sequence, training=True)

# Custom computation with dropout rates set to about 1.0
x = decoder_sequence
x = decoder._self_attention_layer_norm(x)
x = decoder._feedforward_layer_norm(x)

self.assertAllClose(outputs, x, atol=1e-5)

def test_mask_propagation(self):
decoder = TransformerDecoder(
intermediate_dim=4,
11 changes: 8 additions & 3 deletions keras_nlp/src/layers/modeling/transformer_encoder.py
@@ -182,7 +182,9 @@ def build(self, inputs_shape):
)
self.built = True

def call(self, inputs, padding_mask=None, attention_mask=None):
def call(
self, inputs, padding_mask=None, attention_mask=None, training=None
):
"""Forward pass of the TransformerEncoder.

Args:
@@ -194,6 +196,8 @@ def call(self, inputs, padding_mask=None, attention_mask=None):
attention_mask: a boolean Tensor. Customized mask used to mask out
certain tokens. `attention_mask` should have shape
[batch_size, sequence_length, sequence_length].
training: a boolean indicating whether the layer should behave in
training mode or in inference mode.

Returns:
A Tensor of the same shape as the `inputs`.
@@ -213,8 +217,9 @@ def call(self, inputs, padding_mask=None, attention_mask=None):
query=x,
value=x,
attention_mask=self_attention_mask,
training=training,
)
x = self._self_attention_dropout(x)
x = self._self_attention_dropout(x, training=training)
x = x + residual
if not self.normalize_first:
x = self._self_attention_layer_norm(x)
@@ -225,7 +230,7 @@ def call(self, inputs, padding_mask=None, attention_mask=None):
x = self._feedforward_layer_norm(x)
x = self._feedforward_intermediate_dense(x)
x = self._feedforward_output_dense(x)
x = self._feedforward_dropout(x)
x = self._feedforward_dropout(x, training=training)
x = x + residual
if not self.normalize_first:
x = self._feedforward_layer_norm(x)
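
And the matching sketch for `TransformerEncoder` (again, the import path is an assumption based on the file location): the same `training` flag now controls both the attention dropout and the feedforward dropout.

from keras_nlp.src.backend import random
from keras_nlp.src.layers.modeling.transformer_encoder import (
    TransformerEncoder,
)

encoder = TransformerEncoder(intermediate_dim=4, num_heads=2, dropout=0.5)
inputs = random.uniform(shape=[1, 4, 6])

train_out = encoder(inputs, training=True)   # dropout layers active
infer_out = encoder(inputs, training=False)  # deterministic output
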
16 changes: 16 additions & 0 deletions keras_nlp/src/layers/modeling/transformer_encoder_test.py
@@ -83,6 +83,22 @@ def test_value_error_when_invalid_kernel_inititalizer(self):
kernel_initializer="Invalid",
)

def test_training_propagation(self):
encoder = TransformerEncoder(
intermediate_dim=4,
num_heads=2,
dropout=0.99999, # Zeros out the outputs after the dropout layer
)
inputs = random.uniform(shape=[1, 4, 6])
outputs = encoder(inputs, training=True)

# Custom computation with dropout rates set to about 1.0
x = inputs
x = encoder._self_attention_layer_norm(x)
x = encoder._feedforward_layer_norm(x)

self.assertAllClose(outputs, x, atol=1e-5)

def test_mask_propagation(self):
encoder = TransformerEncoder(
intermediate_dim=4,