From b851d5a944bf0ca32d6c096017f29276870e7837 Mon Sep 17 00:00:00 2001 From: Karim Foda Date: Fri, 10 Mar 2023 11:57:02 +0100 Subject: [PATCH 1/4] Fix gradient checkpointing bug in Speecht5 --- .../speech_to_text/modeling_speech_to_text.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index eaef470e0822..ae400c7ca47f 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -1024,6 +1024,13 @@ def forward( hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" " False`..." + ) + use_cache = False + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -1048,12 +1055,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" - " False`..." - ) - use_cache = False def create_custom_forward(module): def custom_forward(*inputs): From 35dda5e3329169dcac5c0366dc74d9bb0d00d4cc Mon Sep 17 00:00:00 2001 From: Karim Foda <35491698+KMFODA@users.noreply.github.com> Date: Fri, 10 Mar 2023 12:12:38 +0100 Subject: [PATCH 2/4] Update modeling_speech_to_text.py --- .../models/speech_to_text/modeling_speech_to_text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index ae400c7ca47f..a8c1a640f2ff 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -1026,7 +1026,7 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: - logger.warning( + logger.warning_once( "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" " False`..." ) use_cache = False From 4da5f4baaf4eddb71787ba4dbbbd3791953e481a Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 10 Mar 2023 11:31:47 +0000 Subject: [PATCH 3/4] Update src/transformers/models/speech_to_text/modeling_speech_to_text.py --- .../models/speech_to_text/modeling_speech_to_text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index a8c1a640f2ff..d08863f8353f 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -1027,7 +1027,7 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" " False`..." + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." ) use_cache = False From 37562679a39b2d0f5f2df4470aa9387c82d27953 Mon Sep 17 00:00:00 2001 From: Karim Foda Date: Fri, 10 Mar 2023 13:14:30 +0100 Subject: [PATCH 4/4] Fix change errors --- .../models/speecht5/modeling_speecht5.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index e67c55c23b4e..975f483395be 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -1662,6 +1662,13 @@ def forward( deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -1691,11 +1698,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False def create_custom_forward(module): def custom_forward(*inputs):