diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index c5c23b8f78..b963270fdd 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -199,11 +199,6 @@ def _update_model_kwargs_for_generation( model_kwargs["attention_mask"] = attention_mask else: # update decoder attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - if token_idx is not None: - attention_mask.index_fill_(1, token_idx, 1) - model_kwargs["attention_mask"] = attention_mask if "decoder_attention_mask" in model_kwargs: decoder_attention_mask = model_kwargs["decoder_attention_mask"] if token_idx is not None: diff --git a/tests/transformers/tests/generation/test_utils.py b/tests/transformers/tests/generation/test_utils.py index 95568ac54e..21250ab169 100644 --- a/tests/transformers/tests/generation/test_utils.py +++ b/tests/transformers/tests/generation/test_utils.py @@ -254,6 +254,10 @@ def _get_encoder_outputs( attention_mask = None return encoder_outputs, input_ids, attention_mask + @staticmethod + def _get_static_shapes(): + return False + def _greedy_generate( self, model, @@ -277,7 +281,7 @@ def _greedy_generate( kwargs = {} model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - + model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False, @@ -337,6 +341,7 @@ def _sample_generate( torch.manual_seed(0) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} self._update_default_model_kwargs(model_kwargs) + model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=True, @@ -406,6 +411,7 @@ def _beam_search_generate( ): model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} self._update_default_model_kwargs(model_kwargs) + model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False, @@ -603,6 +609,7 @@ def _constrained_beam_search_generate( ): model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} self._update_default_model_kwargs(model_kwargs) + model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False, @@ -679,6 +686,7 @@ def _contrastive_generate( kwargs = {} model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} self._update_default_model_kwargs(model_kwargs) + model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False,