diff --git a/keras_nlp/layers/modeling/transformer_decoder.py b/keras_nlp/layers/modeling/transformer_decoder.py
index 3a3cda3f21..cb7c779332 100644
--- a/keras_nlp/layers/modeling/transformer_decoder.py
+++ b/keras_nlp/layers/modeling/transformer_decoder.py
@@ -469,9 +469,11 @@ def _compute_self_attention_mask(
             batch_size,
             input_length,
             output_length,
-            0
-            if self_attention_cache_update_index is None
-            else self_attention_cache_update_index,
+            (
+                0
+                if self_attention_cache_update_index is None
+                else self_attention_cache_update_index
+            ),
         )
         return (
             ops.minimum(decoder_mask, causal_mask)
diff --git a/keras_nlp/models/bloom/bloom_decoder.py b/keras_nlp/models/bloom/bloom_decoder.py
index b3f8b80da7..c5dbb5135e 100644
--- a/keras_nlp/models/bloom/bloom_decoder.py
+++ b/keras_nlp/models/bloom/bloom_decoder.py
@@ -174,9 +174,11 @@ def _compute_attention_mask(
             batch_size,
             input_length,
             output_length,
-            0
-            if attention_cache_update_index is None
-            else attention_cache_update_index,
+            (
+                0
+                if attention_cache_update_index is None
+                else attention_cache_update_index
+            ),
         )
         return (
             ops.minimum(decoder_mask, causal_mask)
diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py
index ae646fb2b6..f80cafd528 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py
@@ -211,9 +211,11 @@ def _compute_self_attention_mask(
             batch_size,
             input_length,
             output_length,
-            0
-            if self_attention_cache_update_index is None
-            else self_attention_cache_update_index,
+            (
+                0
+                if self_attention_cache_update_index is None
+                else self_attention_cache_update_index
+            ),
         )
         return (
             ops.minimum(decoder_mask, causal_mask)
diff --git a/keras_nlp/models/llama/llama_decoder.py b/keras_nlp/models/llama/llama_decoder.py
index 47bac478cc..2137831be1 100644
--- a/keras_nlp/models/llama/llama_decoder.py
+++ b/keras_nlp/models/llama/llama_decoder.py
@@ -172,9 +172,11 @@ def _compute_self_attention_mask(
             batch_size,
             input_length,
             output_length,
-            0
-            if self_attention_cache_update_index is None
-            else self_attention_cache_update_index,
+            (
+                0
+                if self_attention_cache_update_index is None
+                else self_attention_cache_update_index
+            ),
         )
         return (
             ops.minimum(decoder_mask, causal_mask)
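
All four hunks make the same mechanical change: the multiline conditional expression passed as the cache-update-index argument to the causal-mask helper is wrapped in explicit parentheses. Behavior is unchanged; only the layout of the split expression differs (the parenthesized form matches Black's 24.x stable style for split conditional expressions, though that attribution is an assumption, as the diff itself does not name a formatter). A minimal runnable sketch of the before/after pattern, using an illustrative name rather than the *_cache_update_index variables from the diff:

    # Illustrative sketch; `cache_update_index` stands in for the
    # *_cache_update_index variables used in the hunks above.
    def start_index(cache_update_index=None):
        # Before the change, the split conditional expression was bare:
        #     0
        #     if cache_update_index is None
        #     else cache_update_index,
        # After the change, the same expression is parenthesized:
        return (
            0
            if cache_update_index is None
            else cache_update_index
        )

    assert start_index() == 0  # no cache yet: causal mask starts at 0
    assert start_index(5) == 5  # resume the mask from cache index 5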