From f55d3f474b91005d3516ac7a95fd6bd0506fb310 Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Tue, 13 Feb 2024 01:27:38 -0800
Subject: [PATCH 1/2] Fix transformer decoder bug

---
 keras_nlp/layers/modeling/transformer_decoder.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/keras_nlp/layers/modeling/transformer_decoder.py b/keras_nlp/layers/modeling/transformer_decoder.py
index cb7c779332..15c245768c 100644
--- a/keras_nlp/layers/modeling/transformer_decoder.py
+++ b/keras_nlp/layers/modeling/transformer_decoder.py
@@ -416,10 +416,10 @@ def call(
                 cache=cross_attention_cache,
                 cache_update_index=cross_attention_cache_update_index,
             )
-            if self_attention_cache is None:
+            if cross_attention_cache is None:
                 x = attention_output
             else:
-                x, self_attention_cache = attention_output
+                x, cross_attention_cache = attention_output
             x = self._cross_attention_dropout(x)
             x = x + residual
             if not self.normalize_first:

From 015507a5cca89fe16c76e79df1e7b491499625bc Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Tue, 13 Feb 2024 09:20:33 -0800
Subject: [PATCH 2/2] Fix bart and an error in TransformerDecoder with cached
 cross attention

---
 .../modeling/transformer_decoder_test.py | 58 +++++++++++++++++--
 1 file changed, 54 insertions(+), 4 deletions(-)

diff --git a/keras_nlp/layers/modeling/transformer_decoder_test.py b/keras_nlp/layers/modeling/transformer_decoder_test.py
index 4757a0a45c..aa16d9ae5a 100644
--- a/keras_nlp/layers/modeling/transformer_decoder_test.py
+++ b/keras_nlp/layers/modeling/transformer_decoder_test.py
@@ -127,10 +127,7 @@ def test_mask_propagation_without_cross_attention(self):
         self.assertAllEqual(outputs._keras_mask, mask)
 
     def test_cache_call_is_correct(self):
-        batch_size = 2
-        seq_len = 5
-        num_heads = 2
-        key_dim = 4
+        batch_size, seq_len, num_heads, key_dim = 2, 5, 2, 4
         hidden_dim = num_heads * key_dim
 
         input_shape = (batch_size, seq_len, hidden_dim)
@@ -171,6 +168,59 @@ def call(outputs, cache):
         self.assertAllClose(output, no_loop_outputs)
         self.assertAllClose(output_cache, no_loop_cache)
 
+    def test_cache_call_is_correct_with_cross_attention(self):
+        batch_size, seq_len, num_heads, key_dim = 2, 5, 2, 4
+        hidden_dim = num_heads * key_dim
+
+        input_shape = (batch_size, seq_len, hidden_dim)
+        cache_shape = (batch_size, 2, seq_len, num_heads, key_dim)
+        decoder_sequence = random.uniform(shape=input_shape)
+        encoder_sequence = random.uniform(shape=input_shape)
+        empty_cache = ops.zeros(cache_shape)
+        outputs = ops.zeros_like(decoder_sequence)
+
+        layer = TransformerDecoder(
+            intermediate_dim=4,
+            num_heads=num_heads,
+        )
+        no_loop_outputs, no_loop_self_cache, no_loop_cross_cache = layer(
+            decoder_sequence,
+            encoder_sequence,
+            self_attention_cache=empty_cache,
+            self_attention_cache_update_index=0,
+            cross_attention_cache=empty_cache,
+            cross_attention_cache_update_index=0,
+        )
+
+        def loop_body(i, outputs, self_cache, cross_cache):
+            # Decode one token at a time, reusing the caches.
+            start, size = (0, i, 0), (batch_size, 1, hidden_dim)
+            next_input = ops.slice(decoder_sequence, start, size)
+            next_output, self_cache, cross_cache = layer(
+                decoder_sequence=next_input,
+                encoder_sequence=encoder_sequence,
+                self_attention_cache=self_cache,
+                self_attention_cache_update_index=i,
+                cross_attention_cache=cross_cache,
+            )
+            outputs = ops.slice_update(outputs, start, next_output)
+            return i + 1, outputs, self_cache, cross_cache
+
+        def call(outputs, self_cache, cross_cache):
+            _, outputs, self_cache, cross_cache = ops.while_loop(
+                cond=lambda i, outputs, self_cache, cross_cache: i < seq_len,
+                body=loop_body,
+                loop_vars=[0, outputs, self_cache, cross_cache],
+            )
+            return outputs, self_cache, cross_cache
+
+        output, self_cache, cross_cache = call(
+            outputs, empty_cache, no_loop_cross_cache
+        )
+        self.assertAllClose(output, no_loop_outputs)
+        self.assertAllClose(self_cache, no_loop_self_cache)
+        self.assertAllClose(cross_cache, no_loop_cross_cache)
+
     def test_different_feature_dimension_for_encoder_and_decoder_sequence(self):
         decoder = TransformerDecoder(
             intermediate_dim=4,