diff --git a/keras_nlp/src/layers/modeling/cached_multi_head_attention.py b/keras_nlp/src/layers/modeling/cached_multi_head_attention.py
index 13341f4a7d..d4f08c8d77 100644
--- a/keras_nlp/src/layers/modeling/cached_multi_head_attention.py
+++ b/keras_nlp/src/layers/modeling/cached_multi_head_attention.py
@@ -65,6 +65,8 @@ class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
             `cache` (usually the index of the current token being processed
             when running generation). If `cache_update_index=None` while
             `cache` is set, the cache will not be updated.
+        training: a boolean indicating whether the layer should behave in
+            training mode or in inference mode.
 
     Returns:
         An `(attention_output, cache)` tuple. `attention_output` is the result
@@ -83,6 +85,7 @@ def call(
         self,
         query,
         value,
         key=None,
         attention_mask=None,
         cache=None,
         cache_update_index=None,
+        training=None,
     ):
         if (
             hasattr(self, "_build_from_signature")
@@ -133,7 +136,9 @@ def call(
         attention_scores = self._masked_softmax(
             attention_scores, attention_mask
         )
-        attention_scores = self._dropout_layer(attention_scores)
+        attention_scores = self._dropout_layer(
+            attention_scores, training=training
+        )
 
         attention_output = ops.einsum(
             self._combine_equation, attention_scores, value
diff --git a/keras_nlp/src/layers/modeling/cached_multi_head_attention_test.py b/keras_nlp/src/layers/modeling/cached_multi_head_attention_test.py
index c44096593b..f90983de23 100644
--- a/keras_nlp/src/layers/modeling/cached_multi_head_attention_test.py
+++ b/keras_nlp/src/layers/modeling/cached_multi_head_attention_test.py
@@ -91,3 +91,30 @@ def call(outputs, cache):
 
         self.assertAllClose(output, no_loop_outputs)
         self.assertAllClose(output_cache, no_loop_cache)
+
+    def test_training_propagation(self):
+        batch_size = 2
+        seq_len = 5
+        num_heads = 2
+        key_dim = 4
+        hidden_dim = num_heads * key_dim
+
+        input_shape = (batch_size, seq_len, hidden_dim)
+        x = random.uniform(shape=input_shape)
+
+        layer = CachedMultiHeadAttention(
+            num_heads=num_heads,
+            key_dim=key_dim,
+            dropout=0.99999,  # Zeros out the outputs after the dropout layer
+        )
+        outputs = layer(x, x, training=True)
+
+        # Custom computation with dropout rate set to about 1.0
+        value = layer._value_dense(x)
+        attention_scores = ops.zeros((batch_size, num_heads, seq_len, seq_len))
+        attention_output = ops.einsum(
+            layer._combine_equation, attention_scores, value
+        )
+        attention_output = layer._output_dense(attention_output)
+
+        self.assertAllClose(outputs, attention_output, atol=1e-5)
diff --git a/keras_nlp/src/layers/modeling/f_net_encoder.py b/keras_nlp/src/layers/modeling/f_net_encoder.py
index 9cb0906c16..ed78c423ee 100644
--- a/keras_nlp/src/layers/modeling/f_net_encoder.py
+++ b/keras_nlp/src/layers/modeling/f_net_encoder.py
@@ -134,12 +134,14 @@ def build(self, inputs_shape):
         )
         self.built = True
 
-    def call(self, inputs):
+    def call(self, inputs, training=None):
         """Forward pass of the FNetEncoder.
 
         Args:
             inputs: a Tensor. The input data to TransformerEncoder, should be
                 of shape [batch_size, sequence_length, feature_dim].
+            training: a boolean indicating whether the layer should behave in
+                training mode or in inference mode.
 
         Returns:
             A Tensor of the same shape as the `inputs`.
@@ -160,7 +162,7 @@ def add_and_norm(input1, input2, norm_layer):
         def feed_forward(input):
             x = self._intermediate_dense(input)
             x = self._output_dense(x)
-            return self._output_dropout(x)
+            return self._output_dropout(x, training=training)
 
         mixing_output = fourier_transform(inputs)
 
diff --git a/keras_nlp/src/layers/modeling/f_net_encoder_test.py b/keras_nlp/src/layers/modeling/f_net_encoder_test.py
index c6d086a1a7..009ce9dd8e 100644
--- a/keras_nlp/src/layers/modeling/f_net_encoder_test.py
+++ b/keras_nlp/src/layers/modeling/f_net_encoder_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from keras_nlp.src.backend import ops
 from keras_nlp.src.backend import random
 from keras_nlp.src.layers.modeling.f_net_encoder import FNetEncoder
 from keras_nlp.src.tests.test_case import TestCase
@@ -42,3 +43,34 @@ def test_value_error_when_invalid_kernel_initializer(self):
             dropout=0.5,
             kernel_initializer="Invalid",
         )
+
+    def test_training_propagation(self):
+        x = random.uniform(shape=(2, 4, 6))
+        layer = FNetEncoder(
+            intermediate_dim=4,
+            dropout=0.99999,  # Zeros out the outputs after the dropout layer
+        )
+        outputs = layer(x, training=True)
+
+        # Custom computation with dropout rate set to about 1.0
+        def fourier_transform(input):
+            # Apply FFT on the input and take the real part.
+            input_dtype = input.dtype
+            # FFT transforms do not support float16.
+            input = ops.cast(input, "float32")
+            real_in, imaginary_in = (input, ops.zeros_like(input))
+            real_out, _ = ops.fft2((real_in, imaginary_in))
+            return ops.cast(real_out, input_dtype)
+
+        def add_and_norm(input1, input2, norm_layer):
+            return norm_layer(input1 + input2)
+
+        mixing_output = fourier_transform(x)
+        mixing_output = add_and_norm(x, mixing_output, layer._mixing_layer_norm)
+        x = add_and_norm(
+            mixing_output,
+            ops.zeros_like(mixing_output),
+            layer._output_layer_norm,
+        )
+
+        self.assertAllClose(outputs, x, atol=1e-5)
diff --git a/keras_nlp/src/layers/modeling/transformer_decoder.py b/keras_nlp/src/layers/modeling/transformer_decoder.py
index 649f048e70..3b46986aa5 100644
--- a/keras_nlp/src/layers/modeling/transformer_decoder.py
+++ b/keras_nlp/src/layers/modeling/transformer_decoder.py
@@ -279,6 +279,7 @@ def call(
         cross_attention_cache=None,
         cross_attention_cache_update_index=None,
         use_causal_mask=True,
+        training=None,
     ):
         """Forward pass of the TransformerDecoder.
 
         Args:
@@ -315,6 +316,9 @@ def call(
                 `None` (reuse a previously computed `cross_attention_cache`).
             use_causal_mask: bool, defaults to `True`. If true, a causal mask
                 (masking out future input) is applied `on the decoder sequence.
+            training: a boolean indicating whether the layer should behave in
+                training mode or in inference mode.
+
         Returns:
             One of three things, depending on call arguments:
             - `outputs`, if `self_attention_cache` is `None.
@@ -385,12 +389,13 @@ def call(
             attention_mask=self_attention_mask,
             cache=self_attention_cache,
             cache_update_index=self_attention_cache_update_index,
+            training=training,
         )
         if self_attention_cache is None:
             x = attention_output
         else:
             x, self_attention_cache = attention_output
-        x = self._self_attention_dropout(x)
+        x = self._self_attention_dropout(x, training=training)
         x = x + residual
         if not self.normalize_first:
             x = self._self_attention_layer_norm(x)
@@ -412,12 +417,13 @@ def call(
                 attention_mask=cross_attention_mask,
                 cache=cross_attention_cache,
                 cache_update_index=cross_attention_cache_update_index,
+                training=training,
             )
             if cross_attention_cache is None:
                 x = attention_output
             else:
                 x, cross_attention_cache = attention_output
-            x = self._cross_attention_dropout(x)
+            x = self._cross_attention_dropout(x, training=training)
             x = x + residual
             if not self.normalize_first:
                 x = self._cross_attention_layer_norm(x)
@@ -428,7 +434,7 @@ def call(
             x = self._feedforward_layer_norm(x)
         x = self._feedforward_intermediate_dense(x)
         x = self._feedforward_output_dense(x)
-        x = self._feedforward_dropout(x)
+        x = self._feedforward_dropout(x, training=training)
         x = x + residual
         if not self.normalize_first:
             x = self._feedforward_layer_norm(x)
diff --git a/keras_nlp/src/layers/modeling/transformer_decoder_test.py b/keras_nlp/src/layers/modeling/transformer_decoder_test.py
index f5c9a17c7e..0c8bbe7eae 100644
--- a/keras_nlp/src/layers/modeling/transformer_decoder_test.py
+++ b/keras_nlp/src/layers/modeling/transformer_decoder_test.py
@@ -103,6 +103,23 @@ def test_value_error_when_invalid_kernel_inititalizer(self):
             kernel_initializer="Invalid",
         )
 
+    def test_training_propagation(self):
+        decoder = TransformerDecoder(
+            intermediate_dim=4,
+            num_heads=2,
+            dropout=0.99999,  # Zeros out the outputs after the dropout layer
+        )
+        decoder_sequence = random.uniform(shape=[1, 4, 6])
+        encoder_sequence = random.uniform(shape=[1, 4, 6])
+        outputs = decoder(decoder_sequence, encoder_sequence, training=True)
+
+        # Custom computation with dropout rates set to about 1.0
+        x = decoder_sequence
+        x = decoder._self_attention_layer_norm(x)
+        x = decoder._feedforward_layer_norm(x)
+
+        self.assertAllClose(outputs, x, atol=1e-5)
+
     def test_mask_propagation(self):
         decoder = TransformerDecoder(
             intermediate_dim=4,
diff --git a/keras_nlp/src/layers/modeling/transformer_encoder.py b/keras_nlp/src/layers/modeling/transformer_encoder.py
index d04f027e12..0484d33fa0 100644
--- a/keras_nlp/src/layers/modeling/transformer_encoder.py
+++ b/keras_nlp/src/layers/modeling/transformer_encoder.py
@@ -182,7 +182,9 @@ def build(self, inputs_shape):
         )
         self.built = True
 
-    def call(self, inputs, padding_mask=None, attention_mask=None):
+    def call(
+        self, inputs, padding_mask=None, attention_mask=None, training=None
+    ):
         """Forward pass of the TransformerEncoder.
 
         Args:
@@ -194,6 +196,8 @@ def call(self, inputs, padding_mask=None, attention_mask=None):
             attention_mask: a boolean Tensor. Customized mask used to mask out
                 certain tokens. `attention_mask` should have shape
                 [batch_size, sequence_length, sequence_length].
+            training: a boolean indicating whether the layer should behave in
+                training mode or in inference mode.
 
         Returns:
             A Tensor of the same shape as the `inputs`.
@@ -213,8 +217,9 @@ def call(self, inputs, padding_mask=None, attention_mask=None):
             query=x,
             value=x,
             attention_mask=self_attention_mask,
+            training=training,
         )
-        x = self._self_attention_dropout(x)
+        x = self._self_attention_dropout(x, training=training)
         x = x + residual
         if not self.normalize_first:
             x = self._self_attention_layer_norm(x)
@@ -225,7 +230,7 @@ def call(self, inputs, padding_mask=None, attention_mask=None):
             x = self._feedforward_layer_norm(x)
         x = self._feedforward_intermediate_dense(x)
         x = self._feedforward_output_dense(x)
-        x = self._feedforward_dropout(x)
+        x = self._feedforward_dropout(x, training=training)
         x = x + residual
         if not self.normalize_first:
             x = self._feedforward_layer_norm(x)
diff --git a/keras_nlp/src/layers/modeling/transformer_encoder_test.py b/keras_nlp/src/layers/modeling/transformer_encoder_test.py
index 99edf9b0ff..e6395193ce 100644
--- a/keras_nlp/src/layers/modeling/transformer_encoder_test.py
+++ b/keras_nlp/src/layers/modeling/transformer_encoder_test.py
@@ -83,6 +83,22 @@ def test_value_error_when_invalid_kernel_inititalizer(self):
             kernel_initializer="Invalid",
         )
 
+    def test_training_propagation(self):
+        encoder = TransformerEncoder(
+            intermediate_dim=4,
+            num_heads=2,
+            dropout=0.99999,  # Zeros out the outputs after the dropout layer
+        )
+        inputs = random.uniform(shape=[1, 4, 6])
+        outputs = encoder(inputs, training=True)
+
+        # Custom computation with dropout rates set to about 1.0
+        x = inputs
+        x = encoder._self_attention_layer_norm(x)
+        x = encoder._feedforward_layer_norm(x)
+
+        self.assertAllClose(outputs, x, atol=1e-5)
+
     def test_mask_propagation(self):
         encoder = TransformerEncoder(
             intermediate_dim=4,
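
Note: the net effect of the patch is that an explicit `training` flag passed to any of these layers now reaches their internal dropout sublayers instead of being lost at the layer boundary. The sketch below is not part of the PR; it is a minimal usage example assuming the public `keras_nlp.layers.TransformerEncoder` API and a Keras 3 backend (for `keras.ops.convert_to_numpy`).

    import numpy as np
    import keras_nlp
    from keras import ops

    # A toy encoder with enough dropout that the training/inference
    # difference is clearly visible.
    encoder = keras_nlp.layers.TransformerEncoder(
        intermediate_dim=32,
        num_heads=2,
        dropout=0.5,
    )
    inputs = np.random.uniform(size=(2, 8, 16)).astype("float32")

    # With `training=True`, the flag now reaches the internal dropout layers,
    # so two identical calls give different (stochastic) outputs.
    out_a = ops.convert_to_numpy(encoder(inputs, training=True))
    out_b = ops.convert_to_numpy(encoder(inputs, training=True))
    print(np.allclose(out_a, out_b))  # usually False

    # With `training=False`, dropout is disabled and the output is
    # deterministic.
    out_c = ops.convert_to_numpy(encoder(inputs, training=False))
    out_d = ops.convert_to_numpy(encoder(inputs, training=False))
    print(np.allclose(out_c, out_d))  # True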
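The new tests all rely on the same trick: a dropout rate just below 1.0, so that in training mode the dropout sublayers almost surely zero out whatever passes through them, which makes the expected output easy to compute by hand (essentially just the layer norms applied along the residual path). A standalone sketch of that idea, using plain Keras `Dropout` rather than code from the PR:

    import numpy as np
    from keras import layers, ops

    x = np.ones((2, 8), dtype="float32")
    dropout = layers.Dropout(rate=0.99999)

    # In inference mode dropout is the identity function.
    print(ops.convert_to_numpy(dropout(x, training=False)))  # all ones

    # In training mode, with the rate this close to 1.0, essentially every
    # unit is dropped, so the output is (almost surely) all zeros. The new
    # tests use this to verify that `training=True` actually reached the
    # dropout layers.
    print(ops.convert_to_numpy(dropout(x, training=True)))  # all zeros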