diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py
index 80773b8438a9..a631b465d1f3 100755
--- a/src/transformers/models/biogpt/modeling_biogpt.py
+++ b/src/transformers/models/biogpt/modeling_biogpt.py
@@ -703,17 +703,27 @@ def forward(
             cross_attentions=outputs.cross_attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, attention_mask, past_key_values=None, **kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, attention_mask, inputs_embeds=None, past_key_values=None, **kwargs
+    ):
         # only last token for inputs_ids if past is defined in kwargs
         if past_key_values:
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
-        return {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-        }
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+            }
+        )
+
+        return model_inputs
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
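
The change above lets BioGPT's `prepare_inputs_for_generation` route `inputs_embeds` to the model on the first generation step (when `past_key_values` is still `None`) and fall back to token ids once the cache exists. A minimal usage sketch follows; it is not part of the diff, and the checkpoint name and decoding settings are illustrative assumptions:

```python
import torch
from transformers import AutoTokenizer, BioGptForCausalLM

# Assumed checkpoint for illustration only.
tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt")
model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")

inputs = tokenizer("COVID-19 is", return_tensors="pt")

# Build embeddings manually; `generate` forwards them to
# `prepare_inputs_for_generation`, which uses `inputs_embeds` only while
# `past_key_values` is None and switches to `input_ids` afterwards.
inputs_embeds = model.get_input_embeddings()(inputs["input_ids"])

with torch.no_grad():
    output_ids = model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=inputs["attention_mask"],
        max_new_tokens=20,
    )

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```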