diff --git a/docs/source/en/model_doc/whisper.mdx b/docs/source/en/model_doc/whisper.mdx index 4b7a60286184..4c6a79d3fad3 100644 --- a/docs/source/en/model_doc/whisper.mdx +++ b/docs/source/en/model_doc/whisper.mdx @@ -49,6 +49,7 @@ The original code can be found [here](https://github.com/openai/whisper). [[autodoc]] WhisperFeatureExtractor - __call__ + - _mask_input_features ## WhisperProcessor diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py index 5a328db65639..fd0eba32d384 100644 --- a/src/transformers/models/whisper/feature_extraction_whisper.py +++ b/src/transformers/models/whisper/feature_extraction_whisper.py @@ -16,7 +16,7 @@ Feature extractor class for Whisper """ -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import numpy as np from numpy.fft import fft @@ -29,6 +29,124 @@ logger = logging.get_logger(__name__) +# Modified from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices with attention_mask from torch.LongTensor to np.array +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[np.array] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.sum(-1).tolist() if attention_mask is not None else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. + if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + class WhisperFeatureExtractor(SequenceFeatureExtractor): r""" Constructs a Whisper feature extractor. @@ -53,6 +171,33 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): Size of the Fourier transform. padding_value (`float`, *optional*, defaults to 0.0): Padding value used to pad the audio. Should correspond to silences. + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. + mask_time_prob (`float`, *optional*, defaults to 0.0): + Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking + procecure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If + reasoning from the propability of each feature vector to be chosen as the start of the vector span to be + masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the + actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`. + mask_time_length (`int`, *optional*, defaults to 10): + Length of vector span along the time axis. + mask_time_min_masks (`int`, *optional*, defaults to 2),: + The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step, + irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length < + mask_time_min_masks'' + mask_feature_prob (`float`, *optional*, defaults to 0.0): + Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The + masking procecure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over + the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector + span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap + may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is + True`. + mask_feature_length (`int`, *optional*, defaults to 10): + Length of vector span along the feature axis. + mask_feature_min_masks (`int`, *optional*, defaults to 0),: + The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time + step, irrespectively of `mask_feature_prob`. Only relevant if + ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks'' """ model_input_names = ["input_features"] @@ -66,6 +211,12 @@ def __init__( n_fft=400, padding_value=0.0, return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask + mask_time_prob=0.0, + mask_time_length=10, + mask_time_min_masks=2, + mask_feature_prob=0.0, + mask_feature_length=10, + mask_feature_min_masks=0, **kwargs ): super().__init__( @@ -82,6 +233,13 @@ def __init__( self.nb_max_frames = self.n_samples // hop_length self.sampling_rate = sampling_rate self.mel_filters = self.get_mel_filters(sampling_rate, n_fft, n_mels=feature_size) + # SpecAugment related + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length + self.mask_time_min_masks = mask_time_min_masks + self.mask_feature_prob = mask_feature_prob + self.mask_feature_length = mask_feature_length + self.mask_feature_min_masks = mask_feature_min_masks def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32): # Initialize the weights @@ -215,6 +373,60 @@ def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray: return log_spec + # Modified from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states + def _mask_input_features( + self, + input_features: np.ndarray, + mask_time_indices: Optional[np.ndarray] = None, + attention_mask: Optional[np.ndarray] = None, + ) -> np.ndarray: + """ + Masks extracted features along time axis and/or along feature axis according to + [SpecAugment](https://arxiv.org/abs/1904.08779). + + Args: + input_features (`np.ndarray` of shape `(batch_size, feature_size, sequence_length)`): + Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained + by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, + *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, + the [`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion + into a tensor of type `torch.FloatTensor`. + mask_time_indices (`np.ndarray`, *optional*, defaults to None): + Indices to mask extracted features along time axis. + attention_mask (`np.ndarray`, *optional*, defaults to None): + A (right-padded) attention mask which independently shortens the feature axis of each batch dimension. + """ + + # generate indices & apply SpecAugment along time axis + batch_size, hidden_size, sequence_length = input_features.shape + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + input_features[mask_time_indices] = 0 + elif self.mask_time_prob > 0: + # generate indices & apply SpecAugment along time axis + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.mask_time_prob, + mask_length=self.mask_time_length, + attention_mask=attention_mask, + min_masks=self.mask_time_min_masks, + ) + mask_time_indices = np.broadcast_to(mask_time_indices[:, None], (batch_size, hidden_size, sequence_length)) + input_features[mask_time_indices] = 0 + + if self.mask_feature_prob > 0: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.mask_feature_prob, + mask_length=self.mask_feature_length, + min_masks=self.mask_feature_min_masks, + ) + input_features[mask_feature_indices] = 0 + + return input_features + def __call__( self, raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]], @@ -225,6 +437,7 @@ def __call__( padding: Optional[str] = "max_length", max_length: Optional[int] = None, sampling_rate: Optional[int] = None, + apply_spec_augment: bool = False, **kwargs ) -> BatchFeature: """ @@ -266,6 +479,10 @@ def __call__( pipeline. padding_value (`float`, defaults to 0.0): The value that is used to fill the padding values / vectors. + apply_spec_augment (`bool`, *optional*, defaults to `False`): + Whether to apply *SpecAugment* data augmentation to the log-Mel spectrogram features. For reference see + [SpecAugment: A Simple Data Augmentation Method for Automatic Speech + Recognition](https://arxiv.org/abs/1904.08779). """ if sampling_rate is not None: @@ -281,6 +498,8 @@ def __call__( "Failing to do so can result in silent errors that might be hard to debug." ) + return_attention_mask = return_attention_mask or apply_spec_augment + is_batched = bool( isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list))) @@ -307,10 +526,12 @@ def __call__( max_length=max_length if max_length else self.n_samples, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, ) # make sure list is in array format input_features = padded_inputs.get("input_features").transpose(2, 0, 1) + # mono input_features = [self._np_extract_fbank_features(waveform) for waveform in input_features[0]] if isinstance(input_features[0], List): @@ -318,6 +539,15 @@ def __call__( else: padded_inputs["input_features"] = input_features + if apply_spec_augment: + # todo: input_features to np array + padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], 0) + + padded_inputs["input_features"] = self._mask_input_features( + padded_inputs["input_features"], + attention_mask=padded_inputs["attention_mask"][:, :: self.hop_length], + ) + if return_tensors is not None: padded_inputs = padded_inputs.convert_to_tensors(return_tensors) diff --git a/tests/models/whisper/test_feature_extraction_whisper.py b/tests/models/whisper/test_feature_extraction_whisper.py index c03763cdf63f..8ee97170195c 100644 --- a/tests/models/whisper/test_feature_extraction_whisper.py +++ b/tests/models/whisper/test_feature_extraction_whisper.py @@ -223,3 +223,14 @@ def test_integration(self): feaure_extractor = WhisperFeatureExtractor() input_features = feaure_extractor(input_speech, return_tensors="pt").input_features self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4)) + + def test_mask_feat(self): + input_speech = self._load_datasamples(1) + feaure_extractor = WhisperFeatureExtractor( + mask_time_prob=0.1, mask_feature_prob=0.1, mask_time_min_masks=1, mask_feature_min_masks=1 + ) + input_features = feaure_extractor(input_speech, sampling_rate=16_000, apply_spec_augment=True).input_features + # at least feaure_extractor.mask_time_length samples along time should be masked + self.assertTrue((input_features[0, 0] == 0).sum() >= feaure_extractor.mask_time_length) + # at least feaure_extractor.mask_feature_length samples along feature should be masked + self.assertTrue((input_features[0, :, 0] == 0).sum() >= feaure_extractor.mask_feature_length)