Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 35 additions & 39 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ def test_torch_large(self):
self.assertEqual(output, {"text": ""})

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})

@require_torch
Expand All @@ -312,8 +312,8 @@ def test_torch_large_with_input_features(self):
self.assertEqual(output, {"text": ""})

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})

@slow
Expand Down Expand Up @@ -542,11 +542,11 @@ def test_torch_whisper(self):
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})

output = speech_recognizer([filename], chunk_length_s=5, batch_size=4)
output = speech_recognizer([ds[40]["audio"]], chunk_length_s=5, batch_size=4)
self.assertEqual(output, [{"text": " A man said to the universe, Sir, I exist."}])

@require_torch
Expand Down Expand Up @@ -1014,8 +1014,8 @@ def test_torch_speech_encoder_decoder(self):
)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": 'Ein Mann sagte zum Universum : " Sir, ich existiert! "'})

@slow
Expand All @@ -1032,13 +1032,11 @@ def test_simple_wav2vec2(self):
self.assertEqual(output, {"text": ""})

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = asr(filename)
audio = ds[40]["audio"]
output = asr(audio)
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})

filename = ds[40]["file"]
with open(filename, "rb") as f:
data = f.read()
data = Audio().encode_example(ds[40]["audio"])["bytes"]
output = asr(data)
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})

Expand All @@ -1058,13 +1056,11 @@ def test_simple_s2t(self):
self.assertEqual(output, {"text": "(Applausi)"})

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = asr(filename)
audio = ds[40]["audio"]
output = asr(audio)
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})

filename = ds[40]["file"]
with open(filename, "rb") as f:
data = f.read()
data = Audio().encode_example(ds[40]["audio"])["bytes"]
output = asr(data)
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})

Expand All @@ -1078,13 +1074,13 @@ def test_simple_whisper_asr(self):
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
filename = ds[0]["file"]
output = speech_recognizer(filename)
audio = ds[0]["audio"]
output = speech_recognizer(audio)
self.assertEqual(
output,
{"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
)
output = speech_recognizer(filename, return_timestamps=True)
output = speech_recognizer(ds[0]["audio"], return_timestamps=True)
self.assertEqual(
output,
{
Expand All @@ -1100,7 +1096,7 @@ def test_simple_whisper_asr(self):
},
)
speech_recognizer.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
output = speech_recognizer(filename, return_timestamps="word")
output = speech_recognizer(ds[0]["audio"], return_timestamps="word")
# fmt: off
self.assertEqual(
output,
Expand Down Expand Up @@ -1135,7 +1131,7 @@ def test_simple_whisper_asr(self):
"^Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
"Use `return_timestamps='word'` or `return_timestamps=True` respectively.$",
):
_ = speech_recognizer(filename, return_timestamps="char")
_ = speech_recognizer(audio, return_timestamps="char")

@slow
@require_torch
Expand All @@ -1147,8 +1143,8 @@ def test_simple_whisper_translation(self):
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
Expand All @@ -1158,7 +1154,7 @@ def test_simple_whisper_translation(self):
speech_recognizer_2 = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
)
output_2 = speech_recognizer_2(filename)
output_2 = speech_recognizer_2(ds[40]["audio"])
self.assertEqual(output, output_2)

# either use generate_kwargs or set the model's generation_config
Expand All @@ -1170,7 +1166,7 @@ def test_simple_whisper_translation(self):
feature_extractor=feature_extractor,
generate_kwargs={"task": "transcribe", "language": "<|it|>"},
)
output_3 = speech_translator(filename)
output_3 = speech_translator(ds[40]["audio"])
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})

@slow
Expand All @@ -1182,10 +1178,10 @@ def test_whisper_language(self):
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
filename = ds[0]["file"]
audio = ds[0]["audio"]

# 1. English-only model compatible with no language argument
output = speech_recognizer(filename)
output = speech_recognizer(audio)
self.assertEqual(
output,
{"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
Expand All @@ -1197,15 +1193,15 @@ def test_whisper_language(self):
"Cannot specify `task` or `language` for an English-only model. If the model is intended to be multilingual, "
"pass `is_multilingual=True` to generate, or update the generation config.",
):
_ = speech_recognizer(filename, generate_kwargs={"language": "en"})
_ = speech_recognizer(ds[0]["audio"], generate_kwargs={"language": "en"})

# 3. Multilingual model accepts language argument
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-tiny",
framework="pt",
)
output = speech_recognizer(filename, generate_kwargs={"language": "en"})
output = speech_recognizer(ds[0]["audio"], generate_kwargs={"language": "en"})
self.assertEqual(
output,
{"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."},
Expand Down Expand Up @@ -1315,8 +1311,8 @@ def test_xls_r_to_en(self):
)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": "A man said to the universe: “Sir, I exist."})

@slow
Expand All @@ -1331,8 +1327,8 @@ def test_xls_r_from_en(self):
)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."})

@slow
Expand All @@ -1348,9 +1344,8 @@ def test_speech_to_text_leveraged(self):
)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]

output = speech_recognizer(filename)
audio = ds[40]["audio"]
output = speech_recognizer(audio)
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})

@slow
Expand Down Expand Up @@ -1561,6 +1556,7 @@ def test_whisper_longform(self):
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
device=torch_device,
return_timestamps=True, # to allow longform generation
)

ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
Expand Down