From 15367622b3c5bb088b6be7602cd8df036432a335 Mon Sep 17 00:00:00 2001 From: younesbelakda Date: Wed, 22 Feb 2023 08:57:19 +0000 Subject: [PATCH 1/3] fix bug --- src/transformers/models/gpt_neo/modeling_gpt_neo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 82bdf8cee408..820e08f5f989 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -587,7 +587,7 @@ def forward( output_shape = input_shape + (hidden_states.size(-1),) - presents = () if use_cache else None + presents = () if use_cache and not self.gradient_checkpointing else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): From 5f6fa2693e81bfc83725d66be15e252c6d5cffad Mon Sep 17 00:00:00 2001 From: edbeeching Date: Wed, 22 Feb 2023 09:26:01 +0000 Subject: [PATCH 2/3] forward contrib credits from discussions From 318af82d42e0b288e3de61b268bc5ca429b876fa Mon Sep 17 00:00:00 2001 From: younesbelakda Date: Wed, 22 Feb 2023 12:36:56 +0000 Subject: [PATCH 3/3] change logic --- .../models/gpt_neo/modeling_gpt_neo.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 820e08f5f989..3391c9f116b0 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -587,7 +587,14 @@ def forward( output_shape = input_shape + (hidden_states.size(-1),) - presents = () if use_cache and not self.gradient_checkpointing else None + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -595,11 +602,6 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False def create_custom_forward(module): def custom_forward(*inputs):