diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py
index ebbfcfa4c501..03753f104deb 100644
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -24,6 +24,8 @@
 if TYPE_CHECKING:
+    from pyctcdecode import BeamSearchDecoderCTC
+
     from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
 
 logger = logging.get_logger(__name__)
@@ -169,8 +171,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         """
 
-    def __init__(self, feature_extractor: Union["SequenceFeatureExtractor", str], *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        feature_extractor: Union["SequenceFeatureExtractor", str],
+        *,
+        decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
         self.feature_extractor = feature_extractor
 
         if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
@@ -178,9 +186,9 @@ def __init__(self, feature_extractor: Union["SequenceFeatureExtractor", str], *a
         elif (
             feature_extractor._processor_class
             and feature_extractor._processor_class.endswith("WithLM")
-            and kwargs.get("decoder", None) is not None
+            and decoder is not None
         ):
-            self.decoder = kwargs["decoder"]
+            self.decoder = decoder
             self.type = "ctc_with_lm"
         else:
             self.type = "ctc"
@@ -221,6 +229,12 @@ def __call__(
                 `timestamps` along the text for every word in the text. For instance if you get
                 `[{"text": "hi ", "timestamps": (0.5,0.9), {"text": "there", "timestamps": (1.0, 1.5)}]`, then it
                 means the model predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
+            generate_kwargs (`dict`, *optional*):
+                The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
+                complete overview of generate, check the [following
+                guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
+            max_new_tokens (`int`, *optional*):
+                The maximum number of tokens to generate, ignoring the number of tokens in the prompt.
 
         Return:
             `Dict`: A dictionary with the following keys:
@@ -233,23 +247,43 @@ def __call__(
         """
         return super().__call__(inputs, **kwargs)
 
-    def _sanitize_parameters(self, **kwargs):
+    def _sanitize_parameters(
+        self,
+        chunk_length_s=None,
+        stride_length_s=None,
+        ignore_warning=None,
+        decoder_kwargs=None,
+        return_timestamps=None,
+        generate_kwargs=None,
+        max_new_tokens=None,
+    ):
         # No parameters on this pipeline right now
         preprocess_params = {}
-        if "chunk_length_s" in kwargs:
-            preprocess_params["chunk_length_s"] = kwargs["chunk_length_s"]
-        if "stride_length_s" in kwargs:
-            preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
-        if "ignore_warning" in kwargs:
-            preprocess_params["ignore_warning"] = kwargs["ignore_warning"]
+        if chunk_length_s is not None:
+            preprocess_params["chunk_length_s"] = chunk_length_s
+        if stride_length_s is not None:
+            preprocess_params["stride_length_s"] = stride_length_s
+        if ignore_warning is not None:
+            preprocess_params["ignore_warning"] = ignore_warning
+
+        forward_params = {"generate_kwargs": {}}
+        if max_new_tokens is not None:
+            forward_params["generate_kwargs"]["max_new_tokens"] = max_new_tokens
+        if generate_kwargs is not None:
+            if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
+                raise ValueError(
+                    "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
+                    " only 1 version"
+                )
+            forward_params["generate_kwargs"].update(generate_kwargs)
 
         postprocess_params = {}
-        if "decoder_kwargs" in kwargs:
-            postprocess_params["decoder_kwargs"] = kwargs["decoder_kwargs"]
-        if "return_timestamps" in kwargs:
-            postprocess_params["return_timestamps"] = kwargs["return_timestamps"]
+        if decoder_kwargs is not None:
+            postprocess_params["decoder_kwargs"] = decoder_kwargs
+        if return_timestamps is not None:
+            postprocess_params["return_timestamps"] = return_timestamps
 
-        return preprocess_params, {}, postprocess_params
+        return preprocess_params, forward_params, postprocess_params
 
     def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
         if isinstance(inputs, str):
@@ -351,7 +385,10 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warn
                 processed["stride"] = stride
             yield {"is_last": True, **processed, **extra}
 
-    def _forward(self, model_inputs):
+    def _forward(self, model_inputs, generate_kwargs=None):
+        if generate_kwargs is None:
+            generate_kwargs = {}
+
         is_last = model_inputs.pop("is_last")
         if self.type == "seq2seq":
             encoder = self.model.get_encoder()
@@ -376,6 +413,7 @@ def _forward(self, model_inputs):
             tokens = self.model.generate(
                 encoder_outputs=encoder(inputs, attention_mask=attention_mask),
                 attention_mask=attention_mask,
+                **generate_kwargs,
             )
             out = {"tokens": tokens}
diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
index 88a9a088f019..5487d3ff1246 100644
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -169,6 +169,17 @@ def test_small_model_pt_seq2seq(self):
         output = speech_recognizer(waveform)
         self.assertEqual(output, {"text": "あл ش 湯 清 ه ܬ া लᆨしث ल eか u w 全 u"})
 
+    @require_torch
+    def test_small_model_pt_seq2seq_gen_kwargs(self):
+        speech_recognizer = pipeline(
+            model="hf-internal-testing/tiny-random-speech-encoder-decoder",
+            framework="pt",
+        )
+
+        waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
+        output = speech_recognizer(waveform, max_new_tokens=10, generate_kwargs={"num_beams": 2})
+        self.assertEqual(output, {"text": "あл † γ ت ב オ 束 泣 足"})
+
     @slow
     @require_torch
     @require_pyctcdecode
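
For reference, a minimal usage sketch of the two new pipeline arguments added by this diff. The checkpoint is the tiny test model from the new test above; the audio file name is an illustrative placeholder, not something shipped with this change:

```python
# Sketch: exercising the new `max_new_tokens` and `generate_kwargs` pipeline arguments.
# "sample.flac" is a placeholder input file, assumed for illustration only.
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="hf-internal-testing/tiny-random-speech-encoder-decoder",
    framework="pt",
)

# `max_new_tokens` is routed into forward_params["generate_kwargs"] by _sanitize_parameters,
# and any other generation option can be passed through `generate_kwargs`.
output = asr("sample.flac", max_new_tokens=10, generate_kwargs={"num_beams": 2})
print(output["text"])

# Passing `max_new_tokens` both directly and inside `generate_kwargs` raises a ValueError:
# asr("sample.flac", max_new_tokens=10, generate_kwargs={"max_new_tokens": 10})
```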