diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 65169322adc7..fd9286cebfb4 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -697,7 +697,11 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non if past and past[0] is not None: input_ids = input_ids[:, -1:] - return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past or model_kwargs.get("past_key_values"), + } def _reorder_cache(self, past, beam_idx): reordered_past = ()