From e571b043ce8757789d87eb5d3c55136424041ae8 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Wed, 1 Mar 2023 18:50:42 +0100 Subject: [PATCH 1/2] force on the same device --- src/transformers/models/whisper/modeling_whisper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 02e1c6c8433b..444f07a8b3a2 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -996,7 +996,9 @@ def forward( ) # embed positions - positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length).to( + inputs_embeds.device + ) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) From 14a48c4b6684270a020bb199e078863f523d1c93 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Wed, 1 Mar 2023 20:16:01 +0100 Subject: [PATCH 2/2] fix tests --- .../models/whisper/modeling_whisper.py | 4 +--- tests/models/whisper/test_modeling_whisper.py | 14 +++++++++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 444f07a8b3a2..02e1c6c8433b 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -996,9 +996,7 @@ def forward( ) # embed positions - positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length).to( - inputs_embeds.device - ) + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 2faeee9b8a5e..1b7b731a410c 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -284,6 +284,8 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi fx_compatible = False test_pruning = False test_missing_keys = False + # Needs higher percentages after model tester's vocab_size is changed to 200 (PR #21222) + model_split_percents = [0.8, 0.9] input_name = "input_features" @@ -727,7 +729,17 @@ def _create_and_check_torchscript(self, config, inputs_dict): input_features = inputs["input_features"] decoder_input_ids = inputs["decoder_input_ids"] decoder_attention_mask = inputs["decoder_attention_mask"] - traced_model = torch.jit.trace(model, (input_features, decoder_input_ids, decoder_attention_mask)) + # prepare `attention_mask` with shape (batch_size, sequence_length) + attention_mask = torch.ones( + input_features.shape[0], + input_features.shape[-1], + device=input_features.device, + dtype=input_features.dtype, + ) + traced_model = torch.jit.trace( + model, (input_features, attention_mask, decoder_input_ids, decoder_attention_mask) + ) + except RuntimeError: self.fail("Couldn't trace module.")