From e80be0f8697ceb9fec570e585389f88ece335cf7 Mon Sep 17 00:00:00 2001 From: Bhargav Date: Thu, 22 Feb 2024 13:54:30 +0200 Subject: [PATCH] Fixing tests by making static_shapes False --- optimum/habana/transformers/generation/utils.py | 11 +---------- tests/transformers/tests/generation/test_utils.py | 10 +++++++++- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index a97e911b99..599e33aa2b 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -211,15 +211,6 @@ def _update_model_kwargs_for_generation( model_kwargs["attention_mask"] = attention_mask else: # update decoder attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - if token_idx is not None: - attention_mask.index_fill_(1, token_idx, 1) - else: - attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - model_kwargs["attention_mask"] = attention_mask if "decoder_attention_mask" in model_kwargs: decoder_attention_mask = model_kwargs["decoder_attention_mask"] if token_idx is not None: @@ -1391,7 +1382,7 @@ def greedy_search( bucket_size = model_kwargs.get("bucket_size", -1) prev_idx = -1 # avoiding calculate cache_idx when its value is not changing - bucket_internal = model_kwargs["bucket_internal"] + bucket_internal = model_kwargs.get("bucket_internal", False) reduce_recompile = model_kwargs.get("reduce_recompile", False) prompt_len = input_ids.shape[-1] diff --git a/tests/transformers/tests/generation/test_utils.py b/tests/transformers/tests/generation/test_utils.py index 95568ac54e..21250ab169 100644 --- a/tests/transformers/tests/generation/test_utils.py +++ b/tests/transformers/tests/generation/test_utils.py @@ -254,6 +254,10 @@ def _get_encoder_outputs( attention_mask = None return encoder_outputs, input_ids, attention_mask + @staticmethod + def _get_static_shapes(): + return False + def _greedy_generate( self, model, @@ -277,7 +281,7 @@ def _greedy_generate( kwargs = {} model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} - + model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False, @@ -337,6 +341,7 @@ def _sample_generate( torch.manual_seed(0) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} self._update_default_model_kwargs(model_kwargs) + model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=True, @@ -406,6 +411,7 @@ def _beam_search_generate( ): model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} self._update_default_model_kwargs(model_kwargs) + model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False, @@ -603,6 +609,7 @@ def _constrained_beam_search_generate( ): model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} self._update_default_model_kwargs(model_kwargs) + model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False, @@ -679,6 +686,7 @@ def _contrastive_generate( kwargs = {} model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} self._update_default_model_kwargs(model_kwargs) + model.generation_config.static_shapes = self._get_static_shapes() output_generate = model.generate( input_ids, do_sample=False,