18 changes: 13 additions & 5 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -64,7 +64,12 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
     for chunk_start_idx in range(0, inputs_len, step):
         chunk_end_idx = chunk_start_idx + chunk_len
         chunk = inputs[chunk_start_idx:chunk_end_idx]
-        processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
+        processed = feature_extractor(
+            chunk,
+            sampling_rate=feature_extractor.sampling_rate,
+            return_tensors="pt",
+            return_attention_mask=True,
+        )
         if dtype is not None:
             processed = processed.to(dtype=dtype)
         _stride_left = 0 if chunk_start_idx == 0 else stride_left
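Not part of the diff — a minimal sketch of what the newly passed `return_attention_mask=True` changes in the feature extractor output (the checkpoint name is only an illustrative choice): the returned `BatchFeature` now carries an `attention_mask` next to the model inputs, so downstream code can tell real frames from padding.

```python
import numpy as np
from transformers import AutoFeatureExtractor

# Arbitrary example checkpoint, not taken from this PR.
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")

chunk = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz

processed = feature_extractor(
    chunk,
    sampling_rate=feature_extractor.sampling_rate,
    return_tensors="pt",
    return_attention_mask=True,
)

# With return_attention_mask=True the output contains an attention mask
# alongside the features, e.g. dict_keys(['input_features', 'attention_mask']).
print(processed.keys())
```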
@@ -507,11 +512,14 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
if "generation_config" not in generate_kwargs:
generate_kwargs["generation_config"] = self.generation_config

tokens = self.model.generate(
inputs=inputs,
attention_mask=attention_mask,
main_input_name = self.model.main_input_name if hasattr(self.model, "main_input_name") else "inputs"
generate_kwargs = {
main_input_name: inputs,
"attention_mask": attention_mask,
**generate_kwargs,
)
}
tokens = self.model.generate(**generate_kwargs)

# whisper longform generation stores timestamps in "segments"
if return_timestamps == "word" and self.type == "seq2seq_whisper":
if "segments" not in tokens:
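Not part of the diff — a self-contained sketch of why the generate call is now keyed by `model.main_input_name` rather than a hard-coded `inputs=` keyword: speech models name their main input differently (Whisper expects `input_features`, while models such as Wav2Vec2 expect `input_values`), so the pipeline looks the name up on the model. The checkpoint below is only an illustrative choice.

```python
import numpy as np
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

model_id = "openai/whisper-tiny"  # arbitrary example checkpoint, not from this PR
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)

audio = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

# Same lookup as in the diff: prefer the model's declared main input name.
main_input_name = model.main_input_name if hasattr(model, "main_input_name") else "inputs"
print(main_input_name)  # "input_features" for Whisper

generate_kwargs = {main_input_name: inputs[main_input_name]}
tokens = model.generate(**generate_kwargs)
print(processor.batch_decode(tokens, skip_special_tokens=True))
```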