diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 08c568e78a9c..2780355d953b 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -56,14 +56,15 @@ def rescale_stride(stride, ratio): def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None): inputs_len = inputs.shape[0] step = chunk_len - stride_left - stride_right - for i in range(0, inputs_len, step): - # add start and end paddings to the chunk - chunk = inputs[i : i + chunk_len] + for chunk_start_idx in range(0, inputs_len, step): + chunk_end_idx = chunk_start_idx + chunk_len + chunk = inputs[chunk_start_idx:chunk_end_idx] processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt") if dtype is not None: processed = processed.to(dtype=dtype) - _stride_left = 0 if i == 0 else stride_left - is_last = i + step + stride_left >= inputs_len + _stride_left = 0 if chunk_start_idx == 0 else stride_left + # all right strides must be full, otherwise it is the last item + is_last = chunk_end_idx > inputs_len if stride_right > 0 else chunk_end_idx >= inputs_len _stride_right = 0 if is_last else stride_right chunk_len = chunk.shape[0] @@ -77,6 +78,8 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, stride = rescale_stride([stride], ratio)[0] if chunk.shape[0] > _stride_left: yield {"is_last": is_last, "stride": stride, **processed} + if is_last: + break def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions): diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 2bda09fe00c7..f5b0f78dff47 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -526,7 +526,7 @@ def test_whisper_timestamp_prediction(self): output = pipe(array, chunk_length_s=10) self.assertDictEqual( - output, + nested_simplify(output), { "chunks": [ {"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 5.5)}, @@ -548,11 +548,11 @@ def test_whisper_timestamp_prediction(self): }, { "text": " the thousands of spectators, retrievality is not worth thinking about.", - "timestamp": (19.6, 24.98), + "timestamp": (19.6, 26.66), }, { "text": " His instant panic was followed by a small, sharp blow high on his chest.", - "timestamp": (24.98, 30.98), + "timestamp": (26.66, 31.06), }, ], "text": ( @@ -1110,6 +1110,11 @@ def test_chunk_iterator_stride(self): self.assertEqual([o["stride"] for o in outs], [(90, 0, 0), (30, 20, 0)]) self.assertEqual([o["input_values"].shape for o in outs], [(1, 90), (1, 30)]) + outs = list(chunk_iter(inputs, feature_extractor, 36, 6, 6, ratio)) + self.assertEqual(len(outs), 4) + self.assertEqual([o["stride"] for o in outs], [(36, 0, 6), (36, 6, 6), (36, 6, 6), (28, 6, 0)]) + self.assertEqual([o["input_values"].shape for o in outs], [(1, 36), (1, 36), (1, 36), (1, 28)]) + inputs = torch.LongTensor([i % 2 for i in range(100)]) input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[ "input_values"