diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index 0cdd80dd7411..ab07987315a4 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -140,10 +140,6 @@ def test_greedy_generate(self): # check `generate()` and `greedy_search()` are equal kwargs = {} if model.config.is_encoder_decoder: - encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( - model, input_ids, attention_mask - ) - kwargs["encoder_outputs"] = encoder_outputs max_length = 4 output_ids_generate = model.generate( @@ -154,6 +150,13 @@ def test_greedy_generate(self): max_length=max_length, **logits_process_kwargs, ) + + if model.config.is_encoder_decoder: + encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( + model, input_ids, attention_mask + ) + kwargs["encoder_outputs"] = encoder_outputs + with torch.no_grad(): output_ids_greedy = model.greedy_search( input_ids,