From c0fc7600fa90bd22d09a4e8c517cbf53fed399c7 Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Fri, 26 Jan 2024 16:58:40 -0800 Subject: [PATCH] Update black formatting --- keras_nlp/layers/modeling/transformer_decoder.py | 8 +++++--- keras_nlp/models/bloom/bloom_decoder.py | 8 +++++--- keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py | 8 +++++--- keras_nlp/models/llama/llama_decoder.py | 8 +++++--- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/keras_nlp/layers/modeling/transformer_decoder.py b/keras_nlp/layers/modeling/transformer_decoder.py index 3a3cda3f21..cb7c779332 100644 --- a/keras_nlp/layers/modeling/transformer_decoder.py +++ b/keras_nlp/layers/modeling/transformer_decoder.py @@ -469,9 +469,11 @@ def _compute_self_attention_mask( batch_size, input_length, output_length, - 0 - if self_attention_cache_update_index is None - else self_attention_cache_update_index, + ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ), ) return ( ops.minimum(decoder_mask, causal_mask) diff --git a/keras_nlp/models/bloom/bloom_decoder.py b/keras_nlp/models/bloom/bloom_decoder.py index b3f8b80da7..c5dbb5135e 100644 --- a/keras_nlp/models/bloom/bloom_decoder.py +++ b/keras_nlp/models/bloom/bloom_decoder.py @@ -174,9 +174,11 @@ def _compute_attention_mask( batch_size, input_length, output_length, - 0 - if attention_cache_update_index is None - else attention_cache_update_index, + ( + 0 + if attention_cache_update_index is None + else attention_cache_update_index + ), ) return ( ops.minimum(decoder_mask, causal_mask) diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py index ae646fb2b6..f80cafd528 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py @@ -211,9 +211,11 @@ def _compute_self_attention_mask( batch_size, input_length, output_length, - 0 - if self_attention_cache_update_index is None - else self_attention_cache_update_index, + ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ), ) return ( ops.minimum(decoder_mask, causal_mask) diff --git a/keras_nlp/models/llama/llama_decoder.py b/keras_nlp/models/llama/llama_decoder.py index 47bac478cc..2137831be1 100644 --- a/keras_nlp/models/llama/llama_decoder.py +++ b/keras_nlp/models/llama/llama_decoder.py @@ -172,9 +172,11 @@ def _compute_self_attention_mask( batch_size, input_length, output_length, - 0 - if self_attention_cache_update_index is None - else self_attention_cache_update_index, + ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ), ) return ( ops.minimum(decoder_mask, causal_mask)