diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py
index f6c7b0167f6d..dd97449ca8ae 100644
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -56,7 +56,7 @@ def rescale_stride(stride, ratio):
     return new_strides
 
 
-def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, ratio, dtype=None):
+def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None):
     inputs_len = inputs.shape[0]
     step = chunk_len - stride_left - stride_right
     for i in range(0, inputs_len, step):
@@ -68,9 +68,15 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
         _stride_left = 0 if i == 0 else stride_left
         is_last = i + step + stride_left >= inputs_len
         _stride_right = 0 if is_last else stride_right
 
+        chunk_len = chunk.shape[0]
         stride = (chunk_len, _stride_left, _stride_right)
-        if ratio != 1:
+        if "input_features" in processed:
+            processed_len = processed["input_features"].shape[-1]
+        elif "input_values" in processed:
+            processed_len = processed["input_values"].shape[-1]
+        if processed_len != chunk.shape[-1] and rescale:
+            ratio = processed_len / chunk_len
             stride = rescale_stride([stride], ratio)[0]
         if chunk.shape[0] > _stride_left:
             yield {"is_last": is_last, "stride": stride, **processed}
@@ -101,10 +107,10 @@ def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source
         sequence = sequence[begin_idx:]
 
         timestamp_tokens = sequence >= timestamp_begin
-        consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
-        last_timestamp = np.where(timestamp_tokens)[0][-1]
-        consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
-        if seq_idx != 0:
+        if seq_idx != 0 and sum(timestamp_tokens) > 0:
+            consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
+            last_timestamp = np.where(timestamp_tokens)[0][-1]
+            consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
             time -= stride_left + stride_right
             offset = int((time / feature_extractor.sampling_rate) / time_precision)
             overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
@@ -400,13 +406,12 @@ def _sanitize_parameters(
                     " only 1 version"
                 )
             forward_params["generate_kwargs"].update(generate_kwargs)
-        if return_timestamps is not None:
-            forward_params["generate_kwargs"]["return_timestamps"] = return_timestamps
 
         postprocess_params = {}
         if decoder_kwargs is not None:
             postprocess_params["decoder_kwargs"] = decoder_kwargs
         if return_timestamps is not None:
+            forward_params["return_timestamps"] = return_timestamps
             postprocess_params["return_timestamps"] = return_timestamps
         if self.model.config.model_type == "whisper":
             # Whisper is highly specific, if we want timestamps, we need to
@@ -502,9 +507,10 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warn
             if chunk_len < stride_left + stride_right:
                 raise ValueError("Chunk length must be superior to stride length")
 
+            rescale = self.type != "seq2seq_whisper"
             # make sure that
             for item in chunk_iter(
-                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, align_to, self.torch_dtype
+                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, rescale, self.torch_dtype
             ):
                 yield item
         else:
@@ -520,12 +526,11 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warn
                 processed["stride"] = stride
             yield {"is_last": True, **processed, **extra}
 
-    def _forward(self, model_inputs, generate_kwargs=None):
+    def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
         if generate_kwargs is None:
             generate_kwargs = {}
 
         is_last = model_inputs.pop("is_last")
-        return_timestamps = generate_kwargs.pop("return_timestamps", False)
 
         if self.type == "seq2seq":
             encoder = self.model.get_encoder()
@@ -635,9 +640,9 @@ def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, retu
                 # Simply cast from pyctcdecode format to wav2vec2 format to leverage
                 # pre-existing code later
                 chunk_offset = beams[0][2]
-                word_offsets = []
+                offsets = []
                 for word, (start_offset, end_offset) in chunk_offset:
-                    word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
+                    offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
         else:
             skip_special_tokens = self.type != "ctc"
             text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
index d21c00f8acea..3a5dcc7f4304 100644
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -201,8 +201,9 @@ def test_small_model_pt_seq2seq_gen_kwargs(self):
 
     @require_torch
     @require_pyctcdecode
     def test_large_model_pt_with_lm(self):
-        dataset = load_dataset("Narsil/asr_dummy")
-        filename = dataset["test"][3]["file"]
+        dataset = load_dataset("Narsil/asr_dummy", streaming=True)
+        third_item = next(iter(dataset["test"].skip(3)))
+        filename = third_item["file"]
         speech_recognizer = pipeline(
             task="automatic-speech-recognition",
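
Not part of the patch: a minimal usage sketch of the code path the diff reroutes, where `return_timestamps` now reaches `_forward()` through `forward_params` instead of `generate_kwargs`, and chunked audio goes through `chunk_iter()` with the new `rescale` flag. The checkpoint and the audio path below are illustrative placeholders, not values taken from the PR.

    # Hypothetical usage sketch; "sample.flac" is a placeholder audio file.
    from transformers import pipeline

    asr = pipeline(task="automatic-speech-recognition", model="openai/whisper-tiny")

    # chunk_length_s sends the audio through chunk_iter(); return_timestamps is
    # forwarded to _forward() via forward_params rather than generate_kwargs.
    result = asr("sample.flac", chunk_length_s=30, return_timestamps=True)
    print(result["text"])
    print(result["chunks"])  # [{"text": ..., "timestamp": (start, end)}, ...]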