From 3d40e9bbeefa717e5e7c2f22d1e49e7a3cca1790 Mon Sep 17 00:00:00 2001 From: Karim Foda Date: Fri, 10 Mar 2023 11:48:11 +0100 Subject: [PATCH 1/3] Fix gradient checkpointing bug in Speech2Text --- .../speech_to_text/modeling_speech_to_text.py | 13 +++++++------ .../speech_to_text_2/modeling_speech_to_text_2.py | 13 +++++++------ 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index eaef470e0822..ae400c7ca47f 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -1024,6 +1024,13 @@ def forward( hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" " False`..." + ) + use_cache = False + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -1048,12 +1055,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" - " False`..." - ) - use_cache = False def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index e408974a5053..a217d13b1d39 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -632,6 +632,13 @@ def forward( hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" " False`..." + ) + use_cache = False + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -657,12 +664,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" - " False`..." - ) - use_cache = False def create_custom_forward(module): def custom_forward(*inputs): From 543fbc5f5a2b3853230d6fb2585e89ab1468e3d0 Mon Sep 17 00:00:00 2001 From: Karim Foda <35491698+KMFODA@users.noreply.github.com> Date: Fri, 10 Mar 2023 12:11:45 +0100 Subject: [PATCH 2/3] Update modeling_speech_to_text.py --- .../models/speech_to_text/modeling_speech_to_text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index ae400c7ca47f..a8c1a640f2ff 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -1026,7 +1026,7 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: - logger.warning( + logger.warning_once( "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" " False`..." ) use_cache = False From 0d0596827169d8285c5c2c7cca4be00643adb0fe Mon Sep 17 00:00:00 2001 From: Karim Foda <35491698+KMFODA@users.noreply.github.com> Date: Fri, 10 Mar 2023 12:12:18 +0100 Subject: [PATCH 3/3] Update modeling_speech_to_text_2.py --- .../models/speech_to_text_2/modeling_speech_to_text_2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index a217d13b1d39..319589eab144 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -634,7 +634,7 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: - logger.warning( + logger.warning_once( "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" " False`..." ) use_cache = False