31 changes: 18 additions & 13 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -56,7 +56,7 @@ def rescale_stride(stride, ratio):
     return new_strides
 
 
-def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, ratio, dtype=None):
+def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None):
     inputs_len = inputs.shape[0]
     step = chunk_len - stride_left - stride_right
     for i in range(0, inputs_len, step):
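To make the stepping logic above concrete, here is a standalone sketch with illustrative numbers (none of these values come from the PR): consecutive chunks overlap by stride_left + stride_right samples, so the iterator advances by step each time, and the real code additionally marks a chunk as last once i + step + stride_left reaches the end of the input.

# Illustrative chunk boundaries for chunk_iter's stepping logic (values assumed).
inputs_len = 1_000_000
chunk_len, stride_left, stride_right = 480_000, 80_000, 80_000
step = chunk_len - stride_left - stride_right  # 320_000

for i in range(0, inputs_len, step):
    chunk_start, chunk_end = i, min(i + chunk_len, inputs_len)
    print(chunk_start, chunk_end)
# 0 480000
# 320000 800000
# 640000 1000000
# 960000 1000000  (the real code filters out such degenerate trailing chunks)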
@@ -68,9 +68,15 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
         _stride_left = 0 if i == 0 else stride_left
         is_last = i + step + stride_left >= inputs_len
         _stride_right = 0 if is_last else stride_right
+
         chunk_len = chunk.shape[0]
         stride = (chunk_len, _stride_left, _stride_right)
-        if ratio != 1:
+        if "input_features" in processed:
+            processed_len = processed["input_features"].shape[-1]
+        elif "input_values" in processed:
+            processed_len = processed["input_values"].shape[-1]
+        if processed_len != chunk.shape[-1] and rescale:
+            ratio = processed_len / chunk_len
Comment on lines +74 to +79 (Collaborator Author):

This was missing! Fixes the LM tests.
             stride = rescale_stride([stride], ratio)[0]
         if chunk.shape[0] > _stride_left:
             yield {"is_last": is_last, "stride": stride, **processed}
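For context, the rescaling that the new branch computes: the stride tuple starts out in raw audio samples, while the model consumes feature frames, so the ratio of the processed length to the chunk length converts between the two. A minimal sketch with assumed values, equivalent in spirit to rescale_stride([stride], ratio)[0]:

# Hypothetical numbers for illustration only.
chunk_len = 480_000            # 30 s of 16 kHz audio, in raw samples
processed_len = 3000           # length of the extracted feature sequence (assumed)
stride = (480_000, 0, 80_000)  # (chunk_len, stride_left, stride_right) in samples

ratio = processed_len / chunk_len  # samples -> feature frames

# Scale every entry of the stride tuple from sample space into feature-frame space.
rescaled = tuple(int(s * ratio) for s in stride)
print(rescaled)  # (3000, 0, 500)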
@@ -101,10 +107,10 @@ def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source
             sequence = sequence[begin_idx:]
 
         timestamp_tokens = sequence >= timestamp_begin
-        consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
-        last_timestamp = np.where(timestamp_tokens)[0][-1]
-        consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
-        if seq_idx != 0:
+        if seq_idx != 0 and sum(timestamp_tokens) > 0:
+            consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
+            last_timestamp = np.where(timestamp_tokens)[0][-1]
+            consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
Comment on lines -104 to +113 (Collaborator Author):

This just makes sure that if the model outputs no timestamps, we don't throw an error.
             time -= stride_left + stride_right
             offset = int((time / feature_extractor.sampling_rate) / time_precision)
             overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
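To see the failure mode the new guard avoids, a standalone sketch (not pipeline code): np.where over an all-False mask returns an empty index array, so taking [0][-1] raises an IndexError whenever the model emits no timestamp tokens.

import numpy as np

timestamp_begin = 50257                       # illustrative token-id threshold
sequence = np.array([13, 42, 7])              # a sequence with no timestamp tokens
timestamp_tokens = sequence >= timestamp_begin

print(sum(timestamp_tokens))                  # 0 -> the new condition skips the branch

# Without the guard, this line raises IndexError on the empty result:
# last_timestamp = np.where(timestamp_tokens)[0][-1]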
@@ -400,13 +406,12 @@ def _sanitize_parameters(
                     " only 1 version"
                 )
             forward_params["generate_kwargs"].update(generate_kwargs)
-        if return_timestamps is not None:
-            forward_params["generate_kwargs"]["return_timestamps"] = return_timestamps
 
         postprocess_params = {}
         if decoder_kwargs is not None:
             postprocess_params["decoder_kwargs"] = decoder_kwargs
         if return_timestamps is not None:
+            forward_params["return_timestamps"] = return_timestamps
             postprocess_params["return_timestamps"] = return_timestamps
             if self.model.config.model_type == "whisper":
                 # Whisper is highly specific, if we want timestamps, we need to
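The effect of the reroute above, reduced to a self-contained sketch (hypothetical helper, not the actual Pipeline base class): return_timestamps now reaches _forward as an explicit parameter instead of riding inside the shared generate_kwargs dict, while postprocess keeps receiving its own copy.

# Minimal sketch of the parameter routing (names and structure assumed).
def sanitize(return_timestamps=None, decoder_kwargs=None):
    forward_params, postprocess_params = {}, {}
    if decoder_kwargs is not None:
        postprocess_params["decoder_kwargs"] = decoder_kwargs
    if return_timestamps is not None:
        # Routed to _forward directly rather than via generate_kwargs,
        # so popping it there can no longer mutate a shared dict.
        forward_params["return_timestamps"] = return_timestamps
        postprocess_params["return_timestamps"] = return_timestamps
    return forward_params, postprocess_params

fwd, post = sanitize(return_timestamps=True)
print(fwd)   # {'return_timestamps': True}
print(post)  # {'return_timestamps': True}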
@@ -502,9 +507,10 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warn
             if chunk_len < stride_left + stride_right:
                 raise ValueError("Chunk length must be superior to stride length")
 
+            rescale = self.type != "seq2seq_whisper"
             # make sure that
             for item in chunk_iter(
-                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, align_to, self.torch_dtype
+                inputs, self.feature_extractor, chunk_len, stride_left, stride_right, rescale, self.torch_dtype
             ):
                 yield item
         else:
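The reason seq2seq_whisper opts out of rescaling, as far as the surrounding code shows: _find_timestamp_sequence divides strides by feature_extractor.sampling_rate, so it expects them in raw audio samples rather than feature frames. A sketch with assumed values (0.02 s per Whisper timestamp step is an assumption, not taken from this diff):

# Converting a sample-space stride into Whisper timestamp steps (values assumed).
sampling_rate = 16_000
stride_left = 96_000      # 6 s of left overlap, in raw samples
time_precision = 0.02     # seconds per timestamp token (assumed)

overlap_time = int((stride_left / sampling_rate) / time_precision)
print(overlap_time)       # 300 timestamp steps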
@@ -520,12 +526,11 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warn
                 processed["stride"] = stride
             yield {"is_last": True, **processed, **extra}
 
-    def _forward(self, model_inputs, generate_kwargs=None):
+    def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
Comment (Collaborator Author):

Adding this argument prevents the .pop from removing it for other processes.

if generate_kwargs is None:
generate_kwargs = {}

is_last = model_inputs.pop("is_last")
return_timestamps = generate_kwargs.pop("return_timestamps", False)

if self.type == "seq2seq":
encoder = self.model.get_encoder()
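The shared-dict mutation that motivated the new argument can be reproduced in isolation (a standalone sketch, not the pipeline's actual control flow): popping from a kwargs dict that the caller reuses makes the flag vanish after the first call.

# Why popping from a reused kwargs dict is risky (standalone sketch).
def forward_old(model_inputs, generate_kwargs=None):
    # Mutates the caller's dict: the flag disappears after the first chunk.
    return generate_kwargs.pop("return_timestamps", False)

shared = {"return_timestamps": True}
print(forward_old({}, shared))  # True  (first chunk sees the flag)
print(forward_old({}, shared))  # False (the flag was popped away)

# With return_timestamps as an explicit _forward parameter, generate_kwargs
# stays untouched and every chunk sees the same value.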
Expand Down Expand Up @@ -635,9 +640,9 @@ def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, retu
# Simply cast from pyctcdecode format to wav2vec2 format to leverage
# pre-existing code later
chunk_offset = beams[0][2]
word_offsets = []
offsets = []
for word, (start_offset, end_offset) in chunk_offset:
word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
Comment on lines -638 to +645 (Collaborator Author):

Normalized the name of this variable.
         else:
             skip_special_tokens = self.type != "ctc"
             text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
@@ -201,8 +201,9 @@ def test_small_model_pt_seq2seq_gen_kwargs(self):
     @require_torch
     @require_pyctcdecode
     def test_large_model_pt_with_lm(self):
-        dataset = load_dataset("Narsil/asr_dummy")
-        filename = dataset["test"][3]["file"]
+        dataset = load_dataset("Narsil/asr_dummy", streaming=True)
+        third_item = next(iter(dataset["test"].skip(3)))
+        filename = third_item["file"]
 
         speech_recognizer = pipeline(
             task="automatic-speech-recognition",
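For reference, the streaming pattern the updated test relies on (the dataset name is taken from the test itself; the rest is a sketch of the datasets API as commonly used): streaming returns an IterableDataset that fetches samples lazily, and .skip(3) plus next(iter(...)) pulls only the sample at index 3 instead of downloading the whole split.

from datasets import load_dataset

# Streaming mode avoids downloading the full dataset up front.
dataset = load_dataset("Narsil/asr_dummy", streaming=True)

# .skip(3) drops the first three samples; next(iter(...)) then yields the
# sample at index 3 without materializing the rest of the split.
sample = next(iter(dataset["test"].skip(3)))
print(sample["file"])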