diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index a958c8c86a92..cb3ac0ff1d12 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3034,6 +3034,8 @@ def _beam_search( num_beams = beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) if num_beams * batch_size != batch_beam_size: @@ -3437,6 +3439,8 @@ def _beam_sample( num_beams = beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) # init attention / hidden states / scores tuples @@ -3795,6 +3799,8 @@ def _group_beam_search( device = input_ids.device batch_beam_size, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) if return_dict_in_generate and output_scores: @@ -4211,6 +4217,8 @@ def _constrained_beam_search( num_beams = constrained_beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) if num_beams * batch_size != batch_beam_size: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 99f6e84a3036..83c0758f462e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -717,6 +717,19 @@ def test_beam_sample_generate(self): ) self.assertTrue(output_generate.shape[-1] == max_length) + if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters): + input_embeds = model.get_input_embeddings()(input_ids) + beam_kwargs.update({"inputs_embeds": input_embeds}) + output_generate2 = self._beam_sample_generate( + model=model, + input_ids=None, + attention_mask=attention_mask, + max_length=max_length, + beam_kwargs=beam_kwargs, + logits_warper_kwargs=logits_warper_kwargs, + ) + + torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2) def test_beam_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: diff --git a/tests/models/biogpt/test_modeling_biogpt.py b/tests/models/biogpt/test_modeling_biogpt.py index 1055288e5c2d..58dd39e86a58 100644 --- a/tests/models/biogpt/test_modeling_biogpt.py +++ b/tests/models/biogpt/test_modeling_biogpt.py @@ -414,6 +414,10 @@ def test_biogpt_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + @unittest.skip("The `input_embeds` when fed don't produce the same results.") + def test_beam_sample_generate(self): + pass + @require_torch class BioGptModelIntegrationTest(unittest.TestCase):