72 changes: 55 additions & 17 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -24,6 +24,8 @@


if TYPE_CHECKING:
from pyctcdecode import BeamSearchDecoderCTC

from ...feature_extraction_sequence_utils import SequenceFeatureExtractor

logger = logging.get_logger(__name__)
@@ -169,18 +171,24 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):

"""

def __init__(self, feature_extractor: Union["SequenceFeatureExtractor", str], *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(
self,
feature_extractor: Union["SequenceFeatureExtractor", str],
*,
decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
**kwargs
):
super().__init__(**kwargs)
self.feature_extractor = feature_extractor

if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
self.type = "seq2seq"
elif (
feature_extractor._processor_class
and feature_extractor._processor_class.endswith("WithLM")
and kwargs.get("decoder", None) is not None
and decoder is not None
):
self.decoder = kwargs["decoder"]
self.decoder = decoder
self.type = "ctc_with_lm"
else:
self.type = "ctc"
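
With `decoder` now an explicit keyword-only argument, a CTC-with-LM pipeline can be built directly from a `...WithLM` processor. A minimal sketch, assuming a hypothetical checkpoint name and the usual `Wav2Vec2ProcessorWithLM` layout (tokenizer, feature extractor, and pyctcdecode decoder):

```python
from transformers import AutoModelForCTC, AutoProcessor, AutomaticSpeechRecognitionPipeline

checkpoint = "some-org/wav2vec2-base-with-lm"  # hypothetical CTC + LM checkpoint, not a real Hub id
model = AutoModelForCTC.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)  # a ...WithLM processor bundling the decoder

asr = AutomaticSpeechRecognitionPipeline(
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    decoder=processor.decoder,  # keyword-only after this change
)
```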
@@ -221,6 +229,12 @@ def __call__(
`timestamps` along the text for every word in the text. For instance, if you get `[{"text": "hi ",
"timestamps": (0.5, 0.9)}, {"text": "there", "timestamps": (1.0, 1.5)}]`, then it means the model
predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
generate_kwargs (`dict`, *optional*):
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
complete overview of generate, check the [following
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
max_new_tokens (`int`, *optional*):
The maximum number of tokens to generate, ignoring the number of tokens in the prompt.

Return:
`Dict`: A dictionary with the following keys:
@@ -233,23 +247,43 @@
"""
return super().__call__(inputs, **kwargs)
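
For context, a hedged usage sketch of the two new call arguments, reusing the tiny checkpoint from the test added below: `max_new_tokens` is forwarded to `generate`, and any other generation options go through `generate_kwargs`.

```python
import numpy as np
from transformers import pipeline

# Tiny test checkpoint; any speech seq2seq model accepted by the pipeline would do.
asr = pipeline("automatic-speech-recognition", model="hf-internal-testing/tiny-random-speech-encoder-decoder")

waveform = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
output = asr(waveform, max_new_tokens=10, generate_kwargs={"num_beams": 2})
print(output["text"])
```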

def _sanitize_parameters(self, **kwargs):
def _sanitize_parameters(
self,
chunk_length_s=None,
stride_length_s=None,
ignore_warning=None,
decoder_kwargs=None,
return_timestamps=None,
generate_kwargs=None,
max_new_tokens=None,
):
# Split the supported kwargs between the preprocess, forward and postprocess steps
preprocess_params = {}
if "chunk_length_s" in kwargs:
preprocess_params["chunk_length_s"] = kwargs["chunk_length_s"]
if "stride_length_s" in kwargs:
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
if "ignore_warning" in kwargs:
preprocess_params["ignore_warning"] = kwargs["ignore_warning"]
if chunk_length_s is not None:
preprocess_params["chunk_length_s"] = chunk_length_s
if stride_length_s is not None:
preprocess_params["stride_length_s"] = stride_length_s
if ignore_warning is not None:
preprocess_params["ignore_warning"] = ignore_warning

forward_params = {"generate_kwargs": {}}
if max_new_tokens is not None:
forward_params["generate_kwargs"]["max_new_tokens"] = max_new_tokens
if generate_kwargs is not None:
if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
raise ValueError(
"`max_new_tokens` is defined both as an argument and inside `generate_kwargs`, please pass it"
" only once"
)
forward_params["generate_kwargs"].update(generate_kwargs)

postprocess_params = {}
if "decoder_kwargs" in kwargs:
postprocess_params["decoder_kwargs"] = kwargs["decoder_kwargs"]
if "return_timestamps" in kwargs:
postprocess_params["return_timestamps"] = kwargs["return_timestamps"]
if decoder_kwargs is not None:
postprocess_params["decoder_kwargs"] = decoder_kwargs
if return_timestamps is not None:
postprocess_params["return_timestamps"] = return_timestamps

return preprocess_params, {}, postprocess_params
return preprocess_params, forward_params, postprocess_params
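
One consequence of the conflict check above, as a short sketch (again reusing the test checkpoint; the exact error text may differ): passing `max_new_tokens` both as a top-level argument and inside `generate_kwargs` is rejected before any forward pass runs.

```python
import numpy as np
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="hf-internal-testing/tiny-random-speech-encoder-decoder")
waveform = np.zeros(16000, dtype=np.float32)

try:
    # Both forms at once: _sanitize_parameters raises during the call.
    asr(waveform, max_new_tokens=10, generate_kwargs={"max_new_tokens": 20})
except ValueError as err:
    print(err)
```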

def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
if isinstance(inputs, str):
@@ -351,7 +385,10 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warn
processed["stride"] = stride
yield {"is_last": True, **processed, **extra}

def _forward(self, model_inputs):
def _forward(self, model_inputs, generate_kwargs=None):
if generate_kwargs is None:
generate_kwargs = {}

is_last = model_inputs.pop("is_last")
if self.type == "seq2seq":
encoder = self.model.get_encoder()
@@ -376,6 +413,7 @@ def _forward(self, model_inputs):
tokens = self.model.generate(
encoder_outputs=encoder(inputs, attention_mask=attention_mask),
attention_mask=attention_mask,
**generate_kwargs,
)

out = {"tokens": tokens}
11 changes: 11 additions & 0 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -169,6 +169,17 @@ def test_small_model_pt_seq2seq(self):
output = speech_recognizer(waveform)
self.assertEqual(output, {"text": "あл ش 湯 清 ه ܬ া लᆨしث ल eか u w 全 u"})

@require_torch
def test_small_model_pt_seq2seq_gen_kwargs(self):
speech_recognizer = pipeline(
model="hf-internal-testing/tiny-random-speech-encoder-decoder",
framework="pt",
)

waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
output = speech_recognizer(waveform, max_new_tokens=10, generate_kwargs={"num_beams": 2})
self.assertEqual(output, {"text": "あл † γ ت ב オ 束 泣 足"})

@slow
@require_torch
@require_pyctcdecode