[whisper] small changes for faster tests #38236
Changes from all commits
```diff
@@ -27,7 +27,6 @@
 from huggingface_hub import hf_hub_download
 from parameterized import parameterized

-import transformers
 from transformers import WhisperConfig
 from transformers.testing_utils import (
     is_flaky,
@@ -41,7 +40,7 @@
     slow,
     torch_device,
 )
-from transformers.utils import cached_property, is_torch_available, is_torch_xpu_available, is_torchaudio_available
+from transformers.utils import is_torch_available, is_torch_xpu_available, is_torchaudio_available
 from transformers.utils.import_utils import is_datasets_available

 from ...generation.test_utils import GenerationTesterMixin
```
```diff
@@ -1432,33 +1431,22 @@ def test_generate_compilation_all_outputs(self):
 @require_torch
 @require_torchaudio
 class WhisperModelIntegrationTests(unittest.TestCase):
-    def setUp(self):
-        self._unpatched_generation_mixin_generate = transformers.GenerationMixin.generate
-
-    def tearDown(self):
-        transformers.GenerationMixin.generate = self._unpatched_generation_mixin_generate
-
-    @cached_property
-    def default_processor(self):
-        return WhisperProcessor.from_pretrained("openai/whisper-base")
+    _dataset = None
+
+    @classmethod
+    def _load_dataset(cls):
+        # Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process.
+        if cls._dataset is None:
+            cls._dataset = datasets.load_dataset(
+                "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
+            )

     def _load_datasamples(self, num_samples):
-        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-        # automatic decoding with librispeech
+        self._load_dataset()
+        ds = self._dataset
         speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]

         return [x["array"] for x in speech_samples]

-    def _patch_generation_mixin_generate(self, check_args_fn=None):
-        test = self
-
-        def generate(self, *args, **kwargs):
-            if check_args_fn is not None:
-                check_args_fn(*args, **kwargs)
-            return test._unpatched_generation_mixin_generate(self, *args, **kwargs)
-
-        transformers.GenerationMixin.generate = generate
-
     @slow
     def test_tiny_logits_librispeech(self):
         torch_device = "cpu"
```

Comment from the PR author (Contributor) on the removed `default_processor` property: unused
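As a side note, the new `_load_dataset` helper is an instance of a small, reusable pattern: cache an expensive resource on the class so it is fetched at most once per pytest process, and only when a test actually needs it. A minimal self-contained sketch of that pattern (the class, attribute, and resource below are hypothetical, not part of the PR):

```python
import unittest


class ExpensiveResourceTests(unittest.TestCase):
    # Hypothetical sketch of the class-level lazy cache used in the diff above:
    # the resource is created on first use and shared by every test method in
    # this class for the rest of the process.
    _resource = None

    @classmethod
    def _load_resource(cls):
        if cls._resource is None:
            # Stand-in for datasets.load_dataset(...): any expensive, read-only
            # setup (dataset download, model load) would go here.
            cls._resource = sorted(range(5))

    def test_first_use_loads_the_resource(self):
        self._load_resource()
        self.assertEqual(self._resource, [0, 1, 2, 3, 4])

    def test_second_use_hits_the_cache(self):
        self._load_resource()  # no reload: the class attribute is already populated
        self.assertIsNotNone(self._resource)
```

One advantage over eager loading in `setUpClass` is that the download only happens when a test body actually runs, so runs where the `@slow` tests are skipped do not pay for it.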
```diff
@@ -1586,8 +1574,6 @@ def test_large_logits_librispeech(self):

     @slow
     def test_tiny_en_generation(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
         model.to(torch_device)
```

Review comment on lines -1589 to -1590 (the removed `torch_device = "cpu"` and `set_seed(0)`), from a Collaborator: could you run with flakefinder to see if we need seed or not?
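One way to answer the flakefinder question above is to repeat the test a number of times without the seed and check for divergent outcomes, e.g. with the pytest-flakefinder plugin (`pytest --flake-finder --flake-runs=N ...`). A minimal sketch, assuming the plugin is installed and the usual path of this test file; the path and run count are assumptions, and because these integration tests are gated behind `@slow`, `RUN_SLOW=1` must be set in the environment:

```python
# Hypothetical runner: repeat one Whisper integration test several times to see
# whether it stays deterministic without set_seed(0).
# Prerequisites (assumed): `pip install pytest-flakefinder` and RUN_SLOW=1.
import pytest

exit_code = pytest.main(
    [
        "--flake-finder",
        "--flake-runs=10",  # assumed run count
        "tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_tiny_en_generation",
    ]
)
print("flakefinder exit code:", exit_code)
```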
```diff
@@ -1605,8 +1591,6 @@ def test_tiny_en_generation(self):

     @slow
     def test_tiny_generation(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
         model.to(torch_device)
```
```diff
@@ -1623,8 +1607,6 @@ def test_tiny_generation(self):

     @slow
     def test_large_generation(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
         model.to(torch_device)
```
```diff
@@ -1643,7 +1625,6 @@ def test_large_generation(self):

     @slow
     def test_large_generation_multilingual(self):
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
         model.to(torch_device)
```
```diff
@@ -1710,8 +1691,6 @@ def test_large_batched_generation(self):

     @slow
     def test_large_batched_generation_multilingual(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-large")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
         model.to(torch_device)
```
```diff
@@ -2727,11 +2706,6 @@ def test_whisper_longform_single_batch_beam(self):
             "renormalize_logits": True,  # necessary to match OAI beam search implementation
         }

-        def check_gen_kwargs(inputs, generation_config, *args, **kwargs):
-            self.assertEqual(generation_config.num_beams, gen_kwargs["num_beams"])
-
-        self._patch_generation_mixin_generate(check_args_fn=check_gen_kwargs)
-
         torch.manual_seed(0)
         result = model.generate(input_features, **gen_kwargs)
         decoded = processor.batch_decode(result, skip_special_tokens=True)
```
Review discussion on the removed `check_gen_kwargs` / `_patch_generation_mixin_generate` lines:

- this was only used in one test, and that test should use the base `generate`. This overcomplicates things; idk why it was added in the first place 🤔
- It is linked to #29312, but the change here is test-only. So good for me if it works.
- That is implicitly tested: the output checked in the test is different if `num_beams` is not respected :) (just like in all other beam search integration tests: we check the output, which is sensible enough to detect bad usage of flags)
- It's definitely way over the top for what it tried. So yea, let's keep it simple.
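To make the "implicitly tested" point concrete, here is a rough sketch of the shape such a beam-search integration check takes: the decoded text only matches the recorded transcript if `generate` actually honours the requested `num_beams`, so no patching of `GenerationMixin.generate` is needed. The model id, sample selection, and final check below are illustrative assumptions, not the PR's exact test.

```python
# Illustrative sketch (not the PR's exact test): beam search settings are verified
# indirectly, by comparing the decoded output against a transcript recorded with
# the intended settings.
import torch
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]
inputs = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt")

gen_kwargs = {
    "num_beams": 5,
    "renormalize_logits": True,  # kept in the test to match the OAI beam search implementation
}

torch.manual_seed(0)
result = model.generate(inputs.input_features, **gen_kwargs)
decoded = processor.batch_decode(result, skip_special_tokens=True)
print(decoded[0])

# The real integration test asserts equality against a stored reference transcript;
# if num_beams (or another flag) were silently dropped, the decoded text would change
# and that assertion would fail, which is why no extra generate() spying is required.
```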