Skip to content
1 change: 1 addition & 0 deletions docs/source/en/model_doc/whisper.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ The original code can be found [here](https://github.com/openai/whisper).

[[autodoc]] WhisperFeatureExtractor
- __call__
- _mask_input_features

## WhisperProcessor

Expand Down
232 changes: 231 additions & 1 deletion src/transformers/models/whisper/feature_extraction_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Feature extractor class for Whisper
"""

from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

import numpy as np
from numpy.fft import fft
Expand All @@ -29,6 +29,124 @@
logger = logging.get_logger(__name__)


# Modified from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices with attention_mask from torch.LongTensor to np.array
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[np.array] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.

Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
Comment on lines +45 to +55
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can also add the function to the model_dox/whisper.mdx to have it appear in the documentation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I completed docstring of _mask_input_features and added it to model_dox/whisper.mdx

"""
batch_size, sequence_length = shape

if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")

if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)

# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()

def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)

# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length

# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)

return num_masked_span

# compute number of masked spans in batch
input_lengths = (
attention_mask.sum(-1).tolist() if attention_mask is not None else [sequence_length for _ in range(batch_size)]
)

# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []

max_num_masked_span = compute_num_masked_span(sequence_length)

if max_num_masked_span == 0:
return spec_aug_mask

for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)

# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)

# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller then
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]

spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)

spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

return spec_aug_mask


class WhisperFeatureExtractor(SequenceFeatureExtractor):
r"""
Constructs a Whisper feature extractor.
Expand All @@ -53,6 +171,33 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
Size of the Fourier transform.
padding_value (`float`, *optional*, defaults to 0.0):
Padding value used to pad the audio. Should correspond to silences.
return_attention_mask (`bool`, *optional*):
Whether to return the attention mask.
mask_time_prob (`float`, *optional*, defaults to 0.0):
Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
procecure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
mask_time_length (`int`, *optional*, defaults to 10):
Length of vector span along the time axis.
mask_time_min_masks (`int`, *optional*, defaults to 2),:
The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
mask_time_min_masks''
mask_feature_prob (`float`, *optional*, defaults to 0.0):
Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
masking procecure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over
the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
True`.
mask_feature_length (`int`, *optional*, defaults to 10):
Length of vector span along the feature axis.
mask_feature_min_masks (`int`, *optional*, defaults to 0),:
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
step, irrespectively of `mask_feature_prob`. Only relevant if
''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
"""

model_input_names = ["input_features"]
Expand All @@ -66,6 +211,12 @@ def __init__(
n_fft=400,
padding_value=0.0,
return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
mask_time_prob=0.0,
mask_time_length=10,
mask_time_min_masks=2,
mask_feature_prob=0.0,
mask_feature_length=10,
mask_feature_min_masks=0,
**kwargs
):
super().__init__(
Expand All @@ -82,6 +233,13 @@ def __init__(
self.nb_max_frames = self.n_samples // hop_length
self.sampling_rate = sampling_rate
self.mel_filters = self.get_mel_filters(sampling_rate, n_fft, n_mels=feature_size)
# SpecAugment related
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.mask_time_min_masks = mask_time_min_masks
self.mask_feature_prob = mask_feature_prob
self.mask_feature_length = mask_feature_length
self.mask_feature_min_masks = mask_feature_min_masks

def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
# Initialize the weights
Expand Down Expand Up @@ -215,6 +373,60 @@ def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:

return log_spec

# Modified from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
def _mask_input_features(
self,
input_features: np.ndarray,
mask_time_indices: Optional[np.ndarray] = None,
attention_mask: Optional[np.ndarray] = None,
) -> np.ndarray:
"""
Masks extracted features along time axis and/or along feature axis according to
[SpecAugment](https://arxiv.org/abs/1904.08779).

Args:
input_features (`np.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained
by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`,
*e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`,
the [`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion
into a tensor of type `torch.FloatTensor`.
mask_time_indices (`np.ndarray`, *optional*, defaults to None):
Indices to mask extracted features along time axis.
attention_mask (`np.ndarray`, *optional*, defaults to None):
A (right-padded) attention mask which independently shortens the feature axis of each batch dimension.
"""

# generate indices & apply SpecAugment along time axis
batch_size, hidden_size, sequence_length = input_features.shape

if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices
input_features[mask_time_indices] = 0
elif self.mask_time_prob > 0:
# generate indices & apply SpecAugment along time axis
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
mask_prob=self.mask_time_prob,
mask_length=self.mask_time_length,
attention_mask=attention_mask,
min_masks=self.mask_time_min_masks,
)
mask_time_indices = np.broadcast_to(mask_time_indices[:, None], (batch_size, hidden_size, sequence_length))
input_features[mask_time_indices] = 0

if self.mask_feature_prob > 0:
# generate indices & apply SpecAugment along feature axis
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.mask_feature_prob,
mask_length=self.mask_feature_length,
min_masks=self.mask_feature_min_masks,
)
input_features[mask_feature_indices] = 0

return input_features

def __call__(
self,
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
Expand All @@ -225,6 +437,7 @@ def __call__(
padding: Optional[str] = "max_length",
max_length: Optional[int] = None,
sampling_rate: Optional[int] = None,
apply_spec_augment: bool = False,
**kwargs
) -> BatchFeature:
"""
Expand Down Expand Up @@ -266,6 +479,10 @@ def __call__(
pipeline.
padding_value (`float`, defaults to 0.0):
The value that is used to fill the padding values / vectors.
apply_spec_augment (`bool`, *optional*, defaults to `False`):
Whether to apply *SpecAugment* data augmentation to the log-Mel spectrogram features. For reference see
[SpecAugment: A Simple Data Augmentation Method for Automatic Speech
Recognition](https://arxiv.org/abs/1904.08779).
"""

if sampling_rate is not None:
Expand All @@ -281,6 +498,8 @@ def __call__(
"Failing to do so can result in silent errors that might be hard to debug."
)

return_attention_mask = return_attention_mask or apply_spec_augment

is_batched = bool(
isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
Expand All @@ -307,17 +526,28 @@ def __call__(
max_length=max_length if max_length else self.n_samples,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
# make sure list is in array format
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)

# mono
input_features = [self._np_extract_fbank_features(waveform) for waveform in input_features[0]]

if isinstance(input_features[0], List):
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
else:
padded_inputs["input_features"] = input_features

if apply_spec_augment:
# todo: input_features to np array
padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], 0)

padded_inputs["input_features"] = self._mask_input_features(
padded_inputs["input_features"],
attention_mask=padded_inputs["attention_mask"][:, :: self.hop_length],
)

if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

Expand Down
11 changes: 11 additions & 0 deletions tests/models/whisper/test_feature_extraction_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,14 @@ def test_integration(self):
feaure_extractor = WhisperFeatureExtractor()
input_features = feaure_extractor(input_speech, return_tensors="pt").input_features
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))

def test_mask_feat(self):
input_speech = self._load_datasamples(1)
feaure_extractor = WhisperFeatureExtractor(
mask_time_prob=0.1, mask_feature_prob=0.1, mask_time_min_masks=1, mask_feature_min_masks=1
)
input_features = feaure_extractor(input_speech, sampling_rate=16_000, apply_spec_augment=True).input_features
# at least feaure_extractor.mask_time_length samples along time should be masked
self.assertTrue((input_features[0, 0] == 0).sum() >= feaure_extractor.mask_time_length)
# at least feaure_extractor.mask_feature_length samples along feature should be masked
self.assertTrue((input_features[0, :, 0] == 0).sum() >= feaure_extractor.mask_feature_length)