Merged

Changes from 2 commits
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
@@ -130,6 +130,7 @@
is_seqio_available,
is_soundfile_available,
is_spacy_available,
is_speech_available,
is_spqr_available,
is_sudachi_available,
is_sudachi_projection_available,
@@ -1476,6 +1477,13 @@ def require_tiktoken(test_case):
return unittest.skipUnless(is_tiktoken_available(), "test requires TikToken")(test_case)


def require_speech(test_case):
"""
Decorator marking a test that requires speech. These tests are skipped when speech isn't available.
"""
return unittest.skipUnless(is_speech_available(), "test requires speech")(test_case)
Collaborator
Here we have to decide what we really require.

def is_speech_available():
    # For now this depends on torchaudio but the exact dependency might evolve in the future.
    return _torchaudio_available

and in setup.py we have

extras["speech"] = deps_list("torchaudio") + extras["audio"]

If we want to require torchaudio + audio for the tests, then is_speech_available is not enough.
Also, the message "test requires speech" will be confusing, as speech is not a Python library.

If we only want to require torchaudio, then we already have def require_torchaudio(test_case).

cc @eustlb
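
For reference, the torchaudio-only alternative mentioned above would be a decorator along these lines; the exact body and the import path of is_torchaudio_available are assumptions based on the existing availability helpers:

import unittest

from transformers.utils import is_torchaudio_available  # assumed import location


def require_torchaudio(test_case):
    """
    Decorator marking a test that requires torchaudio. The test is skipped when torchaudio isn't installed.
    """
    return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case)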

Collaborator Author

Yes, it would be nice to figure out whether is_speech_available is needed or redundant. If it isn't, I will change the decorator to require_torchaudio; I have no idea on this, though.
As for the message, we could change it to "when the packages required for speech models are not available"?

Collaborator

I am happy to merge for now if the message is changed to

test requires torchaudio

and come back to check the details later.

Collaborator Author
Done!
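
With that change, the merged decorator presumably reads as follows (a sketch; only the skip message differs from the hunk above):

def require_speech(test_case):
    """
    Decorator marking a test that requires speech. These tests are skipped when speech isn't available.
    """
    return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case)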



def get_gpu_count():
"""
Return the number of available gpus (regardless of whether torch, tf or jax is used)
5 changes: 4 additions & 1 deletion tests/models/seamless_m4t/test_modeling_seamless_m4t.py
@@ -18,7 +18,7 @@
import unittest

from transformers import SeamlessM4TConfig, is_speech_available, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import require_speech, require_torch, slow, torch_device
from transformers.trainer_utils import set_seed
from transformers.utils import cached_property

@@ -1028,6 +1028,7 @@ def test_to_swh_text(self):

self.assertListAlmostEqual(expected_wav_slice, output.waveform.squeeze().tolist()[50:60])

@require_speech
@slow
def test_to_rus_speech(self):
model = SeamlessM4TModel.from_pretrained(self.repo_id).to(torch_device)
@@ -1066,6 +1067,7 @@ def test_text_to_text_model(self):
}
self.factory_test_task(SeamlessM4TModel, SeamlessM4TForTextToText, self.input_text, kwargs1, kwargs2)

@require_speech
@slow
def test_speech_to_text_model(self):
kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True, "generate_speech": False}
@@ -1077,6 +1079,7 @@ def test_speech_to_text_model(self):
}
self.factory_test_task(SeamlessM4TModel, SeamlessM4TForSpeechToText, self.input_audio, kwargs1, kwargs2)

@require_speech
@slow
def test_speech_to_speech_model(self):
kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True}
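In both test files the new marker stacks with the existing decorators in the same way; a minimal, hypothetical usage sketch (the class and test names are made up):

import unittest

from transformers.testing_utils import require_speech, require_torch, slow


@require_torch
class ExampleSpeechIntegrationTest(unittest.TestCase):
    @require_speech
    @slow
    def test_speech_path(self):
        # Runs only when RUN_SLOW=1 is set and the speech dependencies (torchaudio) are installed.
        self.assertTrue(True)
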
tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py
@@ -18,7 +18,7 @@
import unittest

from transformers import SeamlessM4Tv2Config, is_speech_available, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import require_speech, require_torch, slow, torch_device
from transformers.trainer_utils import set_seed
from transformers.utils import cached_property

@@ -1095,6 +1095,7 @@ def test_to_swh_text(self):
[-2.001826e-04, 8.580012e-02], [output.waveform.mean().item(), output.waveform.std().item()]
)

@require_speech
@slow
def test_to_rus_speech(self):
model = SeamlessM4Tv2Model.from_pretrained(self.repo_id).to(torch_device)
@@ -1139,6 +1140,7 @@ def test_text_to_text_model(self):
}
self.factory_test_task(SeamlessM4Tv2Model, SeamlessM4Tv2ForTextToText, self.input_text, kwargs1, kwargs2)

@require_speech
@slow
def test_speech_to_text_model(self):
kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True, "generate_speech": False}
@@ -1150,6 +1152,7 @@ def test_speech_to_text_model(self):
}
self.factory_test_task(SeamlessM4Tv2Model, SeamlessM4Tv2ForSpeechToText, self.input_audio, kwargs1, kwargs2)

@require_speech
@slow
def test_speech_to_speech_model(self):
kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True}