# Patch summary (text before the first "diff --git" header is commentary, not part of any hunk).
#
# Target file: tests/transformers/tests/generation/test_utils.py
#
# What the patch does:
#   * Adds GenerationTesterMixin._update_default_model_kwargs(model_kwargs), which mutates the
#     given dict in place, setting "limit_hpu_graphs" = False, "reuse_cache" = False and
#     "bucket_size" = -1. It has no return statement, so callers rely on the in-place mutation.
#   * Calls that helper on model_kwargs immediately before the direct calls to
#     model.greedy_search, model.beam_search, model.group_beam_search and
#     model.constrained_beam_search in the corresponding _*_generate test helpers.
#   * The hunk at old line 268 (model.generate path in _greedy_generate) only inserts a blank
#     line; the helper is NOT called before model.generate there.
#
# NOTE(review): the three keys are presumably device-specific generation options that the
#   low-level search methods expect to find in model_kwargs -- their meaning is not visible in
#   this patch; confirm against the generation utilities under test.
# NOTE(review): the whitespace-only addition in the -268 hunk looks incidental; confirm it is
#   intended before merging.
# NOTE(review): line breaks, indentation and blank context lines below were reconstructed from
#   the hunk headers (line counts) of a whitespace-collapsed paste; verify against the target
#   file before applying.
diff --git a/tests/transformers/tests/generation/test_utils.py b/tests/transformers/tests/generation/test_utils.py
index b9dbd5baac..2c6ad23d22 100644
--- a/tests/transformers/tests/generation/test_utils.py
+++ b/tests/transformers/tests/generation/test_utils.py
@@ -84,6 +84,11 @@ class GenerationTesterMixin:
     all_generative_model_classes = ()
     input_name = "input_ids"
 
+    def _update_default_model_kwargs(self, model_kwargs):
+        model_kwargs["limit_hpu_graphs"] = False
+        model_kwargs["reuse_cache"] = False
+        model_kwargs["bucket_size"] = -1
+
     def _get_input_ids_and_config(self, batch_size=2):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         input_ids = inputs_dict[self.input_name]
@@ -268,6 +273,7 @@ def _greedy_generate(
 
         kwargs = {}
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+
         output_generate = model.generate(
             input_ids,
             do_sample=False,
@@ -294,6 +300,7 @@ def _greedy_generate(
 
         with torch.no_grad():
             model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+            self._update_default_model_kwargs(model_kwargs)
             output_greedy = model.greedy_search(
                 input_ids,
                 max_length=max_length,
@@ -423,6 +430,7 @@ def _beam_search_generate(
 
         with torch.no_grad():
             model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+            self._update_default_model_kwargs(model_kwargs)
             output_beam_search = model.beam_search(
                 input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
                 beam_scorer,
@@ -552,6 +560,8 @@ def _group_beam_search_generate(
 
         with torch.no_grad():
             model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+            self._update_default_model_kwargs(model_kwargs)
+
             output_group_beam_search = model.group_beam_search(
                 input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
                 beam_scorer,
@@ -615,6 +625,7 @@ def _constrained_beam_search_generate(
 
         with torch.no_grad():
             model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+            self._update_default_model_kwargs(model_kwargs)
             output_group_beam_search = model.constrained_beam_search(
                 input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0),
                 constrained_beam_scorer,