diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 0160237304b6..24eb72a0b0f9 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -877,9 +877,7 @@ def new_chunk(): if previous_tokens: if return_timestamps: - # Last token should always be timestamps, so there shouldn't be - # leftover - raise ValueError( + logger.warning( "There was an error while processing timestamps, we haven't found a timestamp as last token. Was" " WhisperTimeStampLogitsProcessor used?" ) diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 5db3e3e46c72..952508dca441 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -17,7 +17,7 @@ import numpy as np import pytest from datasets import load_dataset -from huggingface_hub import snapshot_download +from huggingface_hub import hf_hub_download, snapshot_download from transformers import ( MODEL_FOR_CTC_MAPPING, @@ -39,6 +39,7 @@ require_pyctcdecode, require_tf, require_torch, + require_torch_gpu, require_torchaudio, slow, ) @@ -1158,6 +1159,36 @@ def test_stride(self): output = speech_recognizer({"raw": waveform, "stride": (1000, 8000), "sampling_rate": 16_000}) self.assertEqual(output, {"text": "XB"}) + @slow + @require_torch_gpu + def test_slow_unfinished_sequence(self): + from transformers import GenerationConfig + + pipe = pipeline( + "automatic-speech-recognition", + model="vasista22/whisper-hindi-large-v2", + device="cuda:0", + ) + # Original model wasn't trained with timestamps and has incorrect generation config + pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2") + + audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset") + + out = pipe( + audio, + return_timestamps=True, + ) + self.assertEqual( + out, + { + "chunks": [ + {"text": "", "timestamp": (18.94, 0.0)}, + {"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", "timestamp": (None, None)}, + ], + "text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", + }, + ) + def require_ffmpeg(test_case): """