Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions tests/models/wav2vec2/test_modeling_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout):
try:
_ = in_queue.get(timeout=timeout)

ds = load_dataset("common_voice", "es", split="test", streaming=True)
ds = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="test", streaming=True)
sample = next(iter(ds))

resampled_audio = torchaudio.functional.resample(
Expand All @@ -119,15 +119,15 @@ def _test_wav2vec2_with_lm_invalid_pool(in_queue, out_queue, timeout):
transcription = processor.batch_decode(logits.cpu().numpy(), pool).text

unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
unittest.TestCase().assertEqual(transcription[0], "habitan aguas poco profundas y rocosas")

# force batch_decode to internally create a spawn pool, which should trigger a warning if different than fork
multiprocessing.set_start_method("spawn", force=True)
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl:
transcription = processor.batch_decode(logits.cpu().numpy()).text

unittest.TestCase().assertIn("Falling back to sequential decoding.", cl.out)
unittest.TestCase().assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
unittest.TestCase().assertEqual(transcription[0], "habitan aguas poco profundas y rocosas")
except Exception:
error = f"{traceback.format_exc()}"

Expand Down Expand Up @@ -1833,7 +1833,7 @@ def test_phoneme_recognition(self):
@require_pyctcdecode
@require_torchaudio
def test_wav2vec2_with_lm(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True)
ds = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="test", streaming=True)
sample = next(iter(ds))

resampled_audio = torchaudio.functional.resample(
Expand All @@ -1852,12 +1852,12 @@ def test_wav2vec2_with_lm(self):

transcription = processor.batch_decode(logits.cpu().numpy()).text

self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
self.assertEqual(transcription[0], "habitan aguas poco profundas y rocosas")

@require_pyctcdecode
@require_torchaudio
def test_wav2vec2_with_lm_pool(self):
ds = load_dataset("common_voice", "es", split="test", streaming=True)
ds = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="test", streaming=True)
sample = next(iter(ds))

resampled_audio = torchaudio.functional.resample(
Expand All @@ -1878,7 +1878,7 @@ def test_wav2vec2_with_lm_pool(self):
with multiprocessing.get_context("fork").Pool(2) as pool:
transcription = processor.batch_decode(logits.cpu().numpy(), pool).text

self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
self.assertEqual(transcription[0], "habitan aguas poco profundas y rocosas")

# user-managed pool + num_processes should trigger a warning
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
Expand All @@ -1889,7 +1889,7 @@ def test_wav2vec2_with_lm_pool(self):
self.assertIn("num_process", cl.out)
self.assertIn("it will be ignored", cl.out)

self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
self.assertEqual(transcription[0], "habitan aguas poco profundas y rocosas")

@require_pyctcdecode
@require_torchaudio
Expand Down Expand Up @@ -1957,7 +1957,7 @@ def test_inference_mms_1b_all(self):
LANG_MAP = {"it": "ita", "es": "spa", "fr": "fra", "en": "eng"}

def run_model(lang):
ds = load_dataset("common_voice", lang, split="test", streaming=True)
ds = load_dataset("mozilla-foundation/common_voice_11_0", lang, split="test", streaming=True)
sample = next(iter(ds))

wav2vec2_lang = LANG_MAP[lang]
Expand All @@ -1982,10 +1982,10 @@ def run_model(lang):
return transcription

TRANSCRIPTIONS = {
"it": "mi hanno fatto un'offerta che non potevo proprio rifiutare",
"es": "bien y qué regalo vas a abrir primero",
"fr": "un vrai travail intéressant va enfin être mené sur ce sujet",
"en": "twas the time of day and olof spen slept during the summer",
"it": "il libro ha suscitato molte polemiche a causa dei suoi contenuti",
"es": "habitan aguas poco profundas y rocosas",
"fr": "ce dernier est volé tout au long de l'histoire romaine",
"en": "joe keton disapproved of films and buster also had reservations about the media",
}

for lang in LANG_MAP.keys():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,15 +434,14 @@ def test_offsets_integration_fast_batch(self):
def test_word_time_stamp_integration(self):
import torch

ds = load_dataset("common_voice", "en", split="train", streaming=True)
ds = load_dataset("mozilla-foundation/common_voice_11_0", "en", split="train", streaming=True)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
ds_iter = iter(ds)
sample = next(ds_iter)

processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")

# compare to filename `common_voice_en_100038.mp3` of dataset viewer on https://huggingface.co/datasets/common_voice/viewer/en/train
input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values

with torch.no_grad():
Expand All @@ -461,6 +460,7 @@ def test_word_time_stamp_integration(self):
]

EXPECTED_TEXT = "WHY DOES MILISANDRA LOOK LIKE SHE WANTS TO CONSUME JOHN SNOW ON THE RIVER AT THE WALL"
EXPECTED_TEXT = "THE TRACK APPEARS ON THE COMPILATION ALBUM CRAFT FORKS"

# output words
self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), EXPECTED_TEXT)
Expand All @@ -471,8 +471,8 @@ def test_word_time_stamp_integration(self):
end_times = torch.tensor(self.get_from_offsets(word_time_stamps, "end_time"))

# fmt: off
expected_start_tensor = torch.tensor([1.4199, 1.6599, 2.2599, 3.0, 3.24, 3.5999, 3.7999, 4.0999, 4.26, 4.94, 5.28, 5.6599, 5.78, 5.94, 6.32, 6.5399, 6.6599])
expected_end_tensor = torch.tensor([1.5399, 1.8999, 2.9, 3.16, 3.5399, 3.72, 4.0199, 4.1799, 4.76, 5.1599, 5.5599, 5.6999, 5.86, 6.1999, 6.38, 6.6199, 6.94])
expected_start_tensor = torch.tensor([0.6800, 0.8800, 1.1800, 1.8600, 1.9600, 2.1000, 3.0000, 3.5600, 3.9800])
expected_end_tensor = torch.tensor([0.7800, 1.1000, 1.6600, 1.9200, 2.0400, 2.8000, 3.3000, 3.8800, 4.2800])
# fmt: on

self.assertTrue(torch.allclose(start_times, expected_start_tensor, atol=0.01))
Expand Down