diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 2faeee9b8a5e..1b7b731a410c 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -284,6 +284,8 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi fx_compatible = False test_pruning = False test_missing_keys = False + # Needs higher percentages after model tester's vocab_size is changed to 200 (PR #21222) + model_split_percents = [0.8, 0.9] input_name = "input_features" @@ -727,7 +729,17 @@ def _create_and_check_torchscript(self, config, inputs_dict): input_features = inputs["input_features"] decoder_input_ids = inputs["decoder_input_ids"] decoder_attention_mask = inputs["decoder_attention_mask"] - traced_model = torch.jit.trace(model, (input_features, decoder_input_ids, decoder_attention_mask)) + # prepare `attention_mask` with shape (batch_size, sequence_length) + attention_mask = torch.ones( + input_features.shape[0], + input_features.shape[-1], + device=input_features.device, + dtype=input_features.dtype, + ) + traced_model = torch.jit.trace( + model, (input_features, attention_mask, decoder_input_ids, decoder_attention_mask) + ) + except RuntimeError: self.fail("Couldn't trace module.")