Merged
50 changes: 12 additions & 38 deletions tests/models/whisper/test_modeling_whisper.py
@@ -27,7 +27,6 @@
from huggingface_hub import hf_hub_download
from parameterized import parameterized

import transformers
from transformers import WhisperConfig
from transformers.testing_utils import (
is_flaky,
@@ -41,7 +40,7 @@
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_torch_xpu_available, is_torchaudio_available
from transformers.utils import is_torch_available, is_torch_xpu_available, is_torchaudio_available
from transformers.utils.import_utils import is_datasets_available

from ...generation.test_utils import GenerationTesterMixin
@@ -1432,33 +1431,22 @@ def test_generate_compilation_all_outputs(self):
@require_torch
@require_torchaudio
class WhisperModelIntegrationTests(unittest.TestCase):
def setUp(self):
self._unpatched_generation_mixin_generate = transformers.GenerationMixin.generate
@gante (Contributor Author), May 20, 2025:
this was only used in one test, and that test should use the base generate. This overcomplicates things.

idk why it was added in the first place 🤔

Collaborator:

It is linked to #29312, but the change here is test-only. So good for me if it works.

@gante (Contributor Author):

That is implicitly tested: the output checked in the test is different if num_beams is not respected :)

(just like in all other beam search integration tests: we check the output, which is sensible enough to detect bad usage of flags)

Contributor:

It's definitely way over the top for what it tried. So yea let's keep it simple.
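
For reference, a minimal standalone sketch of the output-based check described above (not part of the diff; it reuses the model and dummy dataset that appear elsewhere in this test file, and the side-by-side decode at the end is purely illustrative):

```python
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio = ds.sort("id")[0]["audio"]
input_features = processor(
    audio["array"], sampling_rate=16_000, return_tensors="pt"
).input_features

# num_beams is exercised through the output itself: if the flag were silently
# dropped, generation would fall back to greedy decoding and the transcription
# would no longer match the beam-search reference hard-coded in the test.
beam_ids = model.generate(input_features, num_beams=5, renormalize_logits=True)
greedy_ids = model.generate(input_features)
print(processor.batch_decode(beam_ids, skip_special_tokens=True))
print(processor.batch_decode(greedy_ids, skip_special_tokens=True))
```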


def tearDown(self):
transformers.GenerationMixin.generate = self._unpatched_generation_mixin_generate

@cached_property
def default_processor(self):
@gante (Contributor Author):
unused

return WhisperProcessor.from_pretrained("openai/whisper-base")
_dataset = None

@classmethod
def _load_dataset(cls):
# Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process.
if cls._dataset is None:
cls._dataset = datasets.load_dataset(
"hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)

def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
self._load_dataset()
ds = self._dataset
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]

return [x["array"] for x in speech_samples]

def _patch_generation_mixin_generate(self, check_args_fn=None):
test = self

def generate(self, *args, **kwargs):
if check_args_fn is not None:
check_args_fn(*args, **kwargs)
return test._unpatched_generation_mixin_generate(self, *args, **kwargs)

transformers.GenerationMixin.generate = generate

@slow
def test_tiny_logits_librispeech(self):
torch_device = "cpu"
@@ -1586,8 +1574,6 @@ def test_large_logits_librispeech(self):

@slow
def test_tiny_en_generation(self):
torch_device = "cpu"
set_seed(0)
Comment on lines -1589 to -1590
Collaborator:

could you run with flakefinder to see if we need seed or not?

@gante (Contributor Author):

RUN_SLOW=1 py.test tests/models/whisper/test_modeling_whisper.py -k test_tiny_en_generation --flake-finder --flake-runs 100 yields no failures

(set_seed(0) comes from the original whisper commit. However, AFAIK, Whisper has no random components -- in fact, many output-checking tests in this file don't set a seed)
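
For illustration, a small standalone sketch of the determinism point above (not part of the diff; the silent waveform is just a stand-in input):

```python
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

# One second of silence; any fixed waveform works for this check.
waveform = torch.zeros(16_000).numpy()
input_features = processor(waveform, sampling_rate=16_000, return_tensors="pt").input_features

# Default generation is greedy and from_pretrained leaves the model in eval
# mode, so repeated runs produce identical token ids without any set_seed call.
first = model.generate(input_features, max_new_tokens=20)
second = model.generate(input_features, max_new_tokens=20)
assert torch.equal(first, second)
```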

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.to(torch_device)
@@ -1605,8 +1591,6 @@ def test_tiny_en_generation(self):

@slow
def test_tiny_generation(self):
torch_device = "cpu"
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
@@ -1623,8 +1607,6 @@ def test_tiny_generation(self):

@slow
def test_large_generation(self):
torch_device = "cpu"
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
model.to(torch_device)
@@ -1643,7 +1625,6 @@ def test_large_generation(self):

@slow
def test_large_generation_multilingual(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
model.to(torch_device)
@@ -1710,8 +1691,6 @@ def test_large_batched_generation(self):

@slow
def test_large_batched_generation_multilingual(self):
torch_device = "cpu"
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
model.to(torch_device)
@@ -2727,11 +2706,6 @@ def test_whisper_longform_single_batch_beam(self):
"renormalize_logits": True, # necessary to match OAI beam search implementation
}

def check_gen_kwargs(inputs, generation_config, *args, **kwargs):
self.assertEqual(generation_config.num_beams, gen_kwargs["num_beams"])

self._patch_generation_mixin_generate(check_args_fn=check_gen_kwargs)

torch.manual_seed(0)
result = model.generate(input_features, **gen_kwargs)
decoded = processor.batch_decode(result, skip_special_tokens=True)