diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index fdab541263bb..eefcdccb461e 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -697,7 +697,9 @@ def forward(
             attentions=outputs.attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
         input_shape = input_ids.shape
 
         # cut decoder_input_ids if past is used
@@ -716,12 +718,21 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)
 
-        return {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "position_ids": position_ids,
-            "past_key_values": past_key_values,
-        }
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "position_ids": position_ids,
+            }
+        )
+
+        return model_inputs
 
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
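
For reference, a minimal usage sketch (not part of the patch) of what this change enables: calling generate() with inputs_embeds instead of input_ids on a GPT-NeoX model. It assumes a transformers version whose generate() accepts inputs_embeds for decoder-only models; the EleutherAI/pythia-70m checkpoint is only an example choice.

# Sketch: generating from precomputed embeddings with a GPT-NeoX model.
# Assumes generate() supports `inputs_embeds` for decoder-only models; the
# checkpoint below is an arbitrary example of a GPT-NeoX model.
from transformers import AutoTokenizer, GPTNeoXForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m")

prompt_ids = tokenizer("The GPT-NeoX architecture", return_tensors="pt").input_ids

# Embed the prompt manually instead of passing token ids.
inputs_embeds = model.get_input_embeddings()(prompt_ids)

# On the first step, prepare_inputs_for_generation forwards `inputs_embeds`;
# once a past_key_values cache exists, it falls back to the sampled input_ids.
output_ids = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))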