From 562f7417e4354f81a364ac58f836e7bc0be047d0 Mon Sep 17 00:00:00 2001 From: Somasree Majumder Date: Wed, 8 Mar 2023 13:47:51 +0530 Subject: [PATCH 1/4] fixing --- src/transformers/models/whisper/modeling_whisper.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 2e4dfa67f9c7..b12bafbd18c4 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1029,6 +1029,13 @@ def forward( hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" + " False`..." + ) + use_cache = False # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -1053,12 +1060,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" - " False`..." - ) - use_cache = False def create_custom_forward(module): def custom_forward(*inputs): From d8cf88f785c2c207cf81c2d8ce92f711d3d8b712 Mon Sep 17 00:00:00 2001 From: Somasree Majumder <56045049+soma2000-lang@users.noreply.github.com> Date: Wed, 8 Mar 2023 18:11:25 +0530 Subject: [PATCH 2/4] Update modeling_whisper.py --- src/transformers/models/whisper/modeling_whisper.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index b12bafbd18c4..1fe1c97ce7fb 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1030,12 +1030,11 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" - " False`..." - ) - use_cache = False + if use_cache: + logger.warning( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" " False`..." + ) + use_cache = False # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None From 2d6687ff460c317f1c16ef1aa125bd0f262a7d75 Mon Sep 17 00:00:00 2001 From: Somasree Majumder <56045049+soma2000-lang@users.noreply.github.com> Date: Wed, 8 Mar 2023 21:29:29 +0530 Subject: [PATCH 3/4] Update modeling_whisper.py --- src/transformers/models/whisper/modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 1fe1c97ce7fb..f083a66fac9a 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1031,7 +1031,7 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: - logger.warning( + logger.warning_once( "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" " False`..." ) use_cache = False From 674c1d3d2b28f418d6f91d4948bfc01652300ca0 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 8 Mar 2023 18:52:32 +0000 Subject: [PATCH 4/4] Update src/transformers/models/whisper/modeling_whisper.py --- src/transformers/models/whisper/modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index f083a66fac9a..cefcac389507 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1032,7 +1032,7 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" " False`..." + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." ) use_cache = False # decoder layers