@bofenghuang
Contributor

@bofenghuang bofenghuang commented Jan 9, 2023

What does this PR do?

Hi @ArthurZucker 👋,

As discussed in another conversation, this PR adds SpecAugment to Whisper models. It was used as one of the regularization methods to train the large-v2 model (openai/whisper#661).

Here SpecAugment is implemented in WhisperFeatureExtractor in NumPy. It masks the computed fbank features along the time axis and the feature axis.

Here are the steps I have in mind. Please correct me if I missed something.

  • Return attention_mask from the pad function to get the actual input lengths in the batch, and rescale it from sample level to feature level (480000 -> 3000)
  • Copy the _compute_mask_indices function of wav2vec2, which will be used to generate masks
  • Add a _mask_input_features function to mask along the time or feature axis
  • Add apply_spec_augment, mask_time_prob, etc. to the config and the __call__ function
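
As a rough illustration of the first step, here is a minimal NumPy sketch of rescaling a sample-level attention mask to feature level. It assumes a hop length of 160 samples and 30 s inputs at 16 kHz (480000 samples); the helper name rescale_attention_mask is hypothetical and is not part of the feature extractor.

```python
import numpy as np

HOP_LENGTH = 160       # samples per feature frame (assumed)
N_SAMPLES = 480_000    # 30 s of audio at 16 kHz (assumed)


def rescale_attention_mask(attention_mask, hop_length=HOP_LENGTH):
    """Hypothetical helper: keep one sample-level mask entry per feature frame."""
    return attention_mask[:, ::hop_length]


# Two clips in one batch: a full 30 s clip and a 10 s clip right-padded to 30 s.
sample_mask = np.zeros((2, N_SAMPLES), dtype=np.int32)
sample_mask[0, :] = 1
sample_mask[1, :160_000] = 1

frame_mask = rescale_attention_mask(sample_mask)
# Number of valid (non-padded) frames per clip.
frame_lengths = frame_mask.sum(axis=-1)

print(frame_mask.shape)        # (2, 3000)
print(frame_lengths.tolist())  # [3000, 1000]
```

The per-clip frame lengths are what a mask generator would need so that spans are only placed on real audio and not on padding.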

It's still in draft. I will add the parameters to config and fix the test errors later :)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jan 9, 2023

The documentation is not available anymore as the PR was closed or merged.

@samuelazran

Seems like a much-needed feature! What is the status? Was this functionality tested?

Collaborator

@ArthurZucker ArthurZucker left a comment

This looks very good! Thanks for the contribution.
I think you can add the argument to the config, and start working on return_attention_mask, which should probably only be set if apply_spec_augment is enabled.

Comment on lines +45 to +55
    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                   independently generated mask spans of length `mask_length` is computed by
                   `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                   actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.
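
To make the span arithmetic in this docstring concrete, here is a simplified sketch for a single sequence. It is not the wav2vec2 _compute_mask_indices function: the name sketch_mask_spans is hypothetical, and attention_mask handling and any randomized rounding of the span count are left out.

```python
import numpy as np


def sketch_mask_spans(length, mask_prob, mask_length, min_masks=0, rng=None):
    """Simplified, hypothetical span masking for one sequence of `length` frames."""
    rng = rng if rng is not None else np.random.default_rng()
    num_valid_starts = length - mask_length + 1
    # Number of spans as described in the docstring: mask_prob * length / mask_length.
    num_spans = max(int(mask_prob * length / mask_length), min_masks)
    # Never request more spans than there are distinct start positions.
    num_spans = min(num_spans, num_valid_starts)
    starts = rng.choice(num_valid_starts, size=num_spans, replace=False)
    mask = np.zeros(length, dtype=bool)
    for start in starts:
        mask[start:start + mask_length] = True
    return mask, num_spans


mask, num_spans = sketch_mask_spans(
    length=3000, mask_prob=0.05, mask_length=10, rng=np.random.default_rng(0)
)
print(num_spans)                 # 15 spans of 10 frames each
print(mask.sum() <= 0.05 * 3000) # True: overlaps can only reduce the masked count
```

With 3000 frames, mask_prob=0.05 and mask_length=10, at most 150 frames are masked; overlapping spans make the actual number smaller, which is why mask_prob is described as an upper bound.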
Collaborator

You can also add the function to model_doc/whisper.mdx to have it appear in the documentation.

Contributor Author

I completed the docstring of _mask_input_features and added it to model_doc/whisper.mdx.

Comment on lines 503 to 506
    padded_inputs["input_features"] = self._mask_input_features(
        padded_inputs["input_features"],
        attention_mask=padded_inputs.attention_mask[:, ::self.hop_length],
    )
Collaborator

Let's make sure that we don't change the return type here! (Previously it was either a List or an np.array.)

Contributor Author

padded_inputs["input_features"] was previously a List[np.array], and will be cast to a tensor on line 509.

Could we make padded_inputs["input_features"] always an np.array, whether or not spec_augment is used, since self._mask_input_features always accepts and returns np.array?

Contributor Author

Hi @ArthurZucker, I think this PR is nearly finished except for this point. What's your opinion here? :)

@ArthurZucker
Collaborator

And as mentioned by @samuelazran, we should add at least one test, if possible comparing with the original masking (if OpenAI added it to their codebase), otherwise an integration test.

@bofenghuang
Contributor Author

I was waiting for validation of the basic functions before continuing the work. Thanks for the comments! I will finish the rest.

@bofenghuang
Contributor Author

Hi @ArthurZucker, do you have any suggestions on how to differentiate the train and validation/test sets so that only the train set is augmented?

I think we may need to add the SpecAugment-related parameters to the __call__ function of WhisperFeatureExtractor, then update the training example script here from

def prepare_dataset(batch):
    # process audio
    sample = batch[audio_column_name]
    inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
    # process audio length
    batch[model_input_name] = inputs.get(model_input_name)[0]
    batch["input_length"] = len(sample["array"])

    # process targets
    input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
    batch["labels"] = tokenizer(input_str).input_ids
    return batch


with training_args.main_process_first(desc="dataset map pre-processing"):
    vectorized_datasets = raw_datasets.map(
        prepare_dataset,
        remove_columns=next(iter(raw_datasets.values())).column_names,
        num_proc=data_args.preprocessing_num_workers,
        desc="preprocess train dataset",
    )

to

def prepare_dataset(batch, **kwargs):
    # process audio
    sample = batch[audio_column_name]
    inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"], **kwargs)
    # process audio length
    batch[model_input_name] = inputs.get(model_input_name)[0]
    batch["input_length"] = len(sample["array"])

    # process targets
    input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
    batch["labels"] = tokenizer(input_str).input_ids
    return batch


with training_args.main_process_first(desc="dataset map pre-processing"):
    vectorized_datasets = DatasetDict()

    if training_args.do_train:
        # NB: also add SpecAugment parameters to DataTrainingArguments
        vectorized_datasets["train"] = raw_datasets["train"].map(
            lambda example: prepare_dataset(
                example,
                apply_spec_augment=data_args.apply_spec_augment,
                mask_time_prob=data_args.mask_time_prob,
                mask_feature_prob=data_args.mask_feature_prob,
            ),
            remove_columns=next(iter(raw_datasets.values())).column_names,
            num_proc=data_args.preprocessing_num_workers,
            desc="preprocess train dataset",
        )

    if training_args.do_eval:
        vectorized_datasets["eval"] = raw_datasets["eval"].map(
            prepare_dataset,
            remove_columns=next(iter(raw_datasets.values())).column_names,
            num_proc=data_args.preprocessing_num_workers,
            desc="preprocess eval dataset",
        )

Also cc @sanchit-gandhi :)

@ArthurZucker
Collaborator

I think I am in favor of just adding the do_spec_augment argument in the call of the feature extractor, which will default to False. The processing of training and validation should indeed be taken care of outside of the modelling.

@sanchit-gandhi
Contributor

sanchit-gandhi commented Jan 24, 2023

Hey @bofenghuang,

Really cool to see this new feature addition for SpecAug! Could well provide a nice boost for Whisper fine-tuning 🚀

Not sure I fully agree that we should add SpecAug to the feature extractor. IMO it's a regularisation technique that belongs in the modelling file which is in many ways analogous to dropout (we wouldn't ever add dropout to the feature extractor - this is a method that relates to the modelling code and thus we add it there).

Adding SpecAug to the feature extractor causes two problems:

  1. We pre-process our training dataset once at the start of training to obtain our log-Mel spectrograms. Using SpecAug in our feature extractor means that we generate a fixed set of masked features in these spectrograms. If we train for multiple epochs, we re-use our pre-processed dataset, and so have the same masked features for each epoch. This is analogous to dropping out the same nodes each time we do dropout -> the model will fit to these fixed SpecAug features, defeating the point of using this regularisation technique! What we actually want to do is mask different features in our spectrograms each time we use the data, i.e. mask in a stochastic way.
  2. We need different pre-processing logic for our train/eval sets. We need to 'turn on' SpecAug for the train set and 'turn off' SpecAug for the eval set.

Both of these problems are bypassed by putting SpecAug in the modelling file:

  1. We mask a different set of features at each forward pass in a stochastic way ('true' form of dropout)
  2. We only apply SpecAug when we train, which we can access with the attribute self.training. See:
    elif self.config.mask_time_prob > 0 and self.training:
        mask_time_indices = _compute_mask_indices(

So if it's ok with you I think we should modify this PR to move the SpecAug logic to the modelling file!
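
The two points above can be sketched with a small NumPy stand-in for a module that has a training flag. This is only an illustration under stated assumptions: the class SketchSpecAugment is hypothetical, masked frames are filled with zeros by assumption, and a real model would operate on framework tensors rather than NumPy arrays.

```python
import numpy as np


class SketchSpecAugment:
    """Hypothetical stand-in for a model-side time masking step."""

    def __init__(self, mask_time_prob, mask_time_length, seed=None):
        self.mask_time_prob = mask_time_prob
        self.mask_time_length = mask_time_length
        self.training = True  # mirrors the `self.training` flag mentioned above
        self._rng = np.random.default_rng(seed)

    def __call__(self, input_features):
        # input_features: (batch, n_mels, n_frames)
        if not (self.training and self.mask_time_prob > 0):
            return input_features  # eval mode: features pass through untouched
        batch_size, _, n_frames = input_features.shape
        num_spans = int(self.mask_time_prob * n_frames / self.mask_time_length)
        num_valid_starts = n_frames - self.mask_time_length + 1
        masked = input_features.copy()
        for b in range(batch_size):
            # A fresh set of span starts is drawn on every call.
            starts = self._rng.choice(num_valid_starts, size=num_spans, replace=False)
            for start in starts:
                masked[b, :, start:start + self.mask_time_length] = 0.0  # assumed fill value
        return masked


features = np.ones((2, 80, 3000))
augment = SketchSpecAugment(mask_time_prob=0.05, mask_time_length=10, seed=0)

first_pass = augment(features)   # masks drawn for this call
second_pass = augment(features)  # different masks on the next call

augment.training = False
eval_pass = augment(features)    # no masking outside training
```

Because the mask is drawn inside the forward call, each epoch sees a different masking of the same pre-processed spectrogram, and switching the flag off disables augmentation for evaluation without touching the dataset pipeline.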

@ArthurZucker
Collaborator

Oh I see, thanks for thinking this far ahead @sanchit-gandhi! You are indeed right 👍🏻 Sorry @bofenghuang for misleading you 😅

@bofenghuang
Contributor Author

Hi @sanchit-gandhi,

Thanks, and I totally agree with you! I put it in the feature extractor only because it's a NumPy version. I think we may need to rewrite it in PyTorch if we want to have it in the modeling file? cc @ArthurZucker

@sanchit-gandhi
Contributor

sanchit-gandhi commented Jan 24, 2023

Think we can apply the same logic that we do in Wav2Vec2 and compute the mask using NumPy (no matmuls here, simply building a binary array of indices to mask/not mask in a stochastic way) and apply the mask in PyTorch to our tensors (hidden states).

So _compute_mask_indices is NumPy:

And _mask_hidden_states is PyTorch:

You can probably copy these two methods directly from modeling_wav2vec2.py and apply the masking as required to the input_features in Whisper!
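
As a hedged sketch of the "compute the mask as a binary array, then apply it to the features" split described above, the following uses NumPy arrays as a stand-in for tensors. The function apply_masks is hypothetical, and filling masked positions with zeros is an assumption rather than what the wav2vec2 code does with its hidden states.

```python
import numpy as np


def apply_masks(features, time_mask=None, feature_mask=None, fill_value=0.0):
    """Hypothetical helper applying 2-D boolean masks to 3-D features.

    features:     (batch, n_mels, n_frames)
    time_mask:    (batch, n_frames) boolean, True where a frame is masked
    feature_mask: (batch, n_mels) boolean, True where a mel bin is masked
    """
    out = features.copy()
    if time_mask is not None:
        # Broadcast over the mel axis so whole frames are masked.
        out = np.where(time_mask[:, None, :], fill_value, out)
    if feature_mask is not None:
        # Broadcast over the time axis so whole mel bins are masked.
        out = np.where(feature_mask[:, :, None], fill_value, out)
    return out


features = np.ones((2, 4, 6))

time_mask = np.zeros((2, 6), dtype=bool)
time_mask[0, 1:3] = True       # mask frames 1 and 2 of the first example

feature_mask = np.zeros((2, 4), dtype=bool)
feature_mask[1, 0] = True      # mask mel bin 0 of the second example

masked = apply_masks(features, time_mask, feature_mask)
print(masked.sum())  # 34.0 (48 entries minus 8 time-masked and 6 feature-masked)
```

Only the index generation needs random sampling logic; the application step is a broadcasted select, which is why it translates directly to tensor operations in the modeling file.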

@bofenghuang bofenghuang mentioned this pull request Jan 25, 2023
12 tasks
@bofenghuang bofenghuang deleted the add-specaugment-to-whisper branch November 8, 2024 13:26