diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index a8c1a640f2ff..d08863f8353f 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -1027,7 +1027,7 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache =" " False`..." + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." ) use_cache = False diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index e67c55c23b4e..975f483395be 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -1662,6 +1662,13 @@ def forward( deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -1691,11 +1698,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False def create_custom_forward(module): def custom_forward(*inputs):