Add BART DLM PyTorch pretraining example #18904
Closed
Implements a pretraining example for BART (denoising language model). A big focus was on getting the data denoising as close as possible to the original fairseq, but at the dataloader level instead of the dataset level.
Heavily inspired by the fairseq implementation and the Flax implementation. (See the "HF (Flax), fairseq, and current implementation" section below.) Looking for some feedback. Please see "Questions/Uncertainties".

Some notes
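As a rough sketch of what dataloader-level denoising means in practice, here is a hypothetical, heavily simplified collator (the real example implements BART's full noising pipeline, not just token masking; the mask id is made up):

```python
import numpy as np

class DenoisingCollator:
    """Toy sketch: corrupt each batch on the fly, so every epoch sees a
    fresh corruption of the same samples (dataloader level), instead of a
    single corruption fixed once before training (dataset level)."""

    def __init__(self, mask_token_id=4, mask_ratio=0.3, seed=0):
        self.mask_token_id = mask_token_id
        self.mask_ratio = mask_ratio
        self.rng = np.random.default_rng(seed)

    def __call__(self, batch):
        input_ids = np.array(batch, dtype=np.int64)
        labels = input_ids.copy()  # labels stay uncorrupted
        mask = self.rng.random(input_ids.shape) < self.mask_ratio
        input_ids[mask] = self.mask_token_id
        return {"input_ids": input_ids, "labels": labels}
```

Because the corruption happens in `__call__`, a sample drawn twice is generally corrupted differently each time.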
Default values
The defaults are set to the given BART args. These differ from the Flax defaults in one respect, namely `poisson_lambda`, which is now set to 3.5 instead of 3.0.

HF (Flax), fairseq, and current implementation
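For reference, `poisson_lambda` is the lambda of the Poisson distribution from which the infilling span lengths are drawn, so the mean sampled span length tracks it; a quick sketch:

```python
import numpy as np

# Draw infilling span lengths the way a lambda=3.5 setting implies
# (illustrative only; not the PR's sampling code).
rng = np.random.default_rng(42)
span_lengths = rng.poisson(lam=3.5, size=10_000)

# The empirical mean span length should be close to lambda.
print(span_lengths.mean())
```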
There are some differences in implementation between fairseq, the HF Flax example, and this PyTorch implementation.
The `argwhere` in the Flax example, in this position, is not the same as what happens in fairseq. In fairseq, we explicitly check that the previous token was not a "full stop" (padding token), but in HF we only check whether the current token is a full stop. In the current example I also explicitly check that the next token is not a full stop, in case of padding. (In practice that should be a non-issue, though, since all batches/samples should have the same sequence length and there should not be any padding.)
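The difference can be sketched as follows (token id 1 stands in for the "full stop"/padding separator; all ids and the exact index convention are made up for illustration):

```python
import numpy as np

tokens = np.array([0, 5, 6, 1, 1, 7, 8, 1, 2])
full_stops = tokens == 1

# HF Flax-style: every full stop counts as a sentence end,
# so a run of padding yields multiple "ends".
hf_ends = np.argwhere(full_stops).ravel()

# fairseq-style: a full stop only ends a sentence if the previous
# token was not also a full stop, collapsing runs of padding.
fs_ends = np.argwhere(full_stops[1:] & ~full_stops[:-1]).ravel() + 1

print(hf_ends, fs_ends)
```

On this toy sequence the HF-style check fires twice inside the padding run, while the fairseq-style check fires once.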
The sentence permutation in the Flax example does not keep the separators intact (bug report), so I have reimplemented that method so that sentences in a sequence are still separated by a padding token, even after permutation.
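A minimal sketch of the intended behavior, assuming a made-up separator id of 1 (this is not the actual code in the PR): split on the separator, shuffle whole sentences, and rejoin so each sentence keeps its trailing separator.

```python
import random

def permute_sentences(tokens, sep_id=1, rng=None):
    rng = rng or random.Random(0)
    sentences, current = [], []
    for tok in tokens:
        current.append(tok)
        if tok == sep_id:          # separator closes the current sentence
            sentences.append(current)
            current = []
    if current:                    # sequence may not end on a separator
        sentences.append(current)
    rng.shuffle(sentences)         # permute whole sentences only
    return [tok for sent in sentences for tok in sent]
```

Since each sentence carries its own separator through the shuffle, the boundaries survive the permutation.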
In fairseq, by default, only the first and last tokens are excluded from masking and all others are prone to it. The HF implementation seems sensible, so I follow that. `get_special_tokens_mask` includes the padding token, though, so there is no need to add that separately.
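As a framework-free sketch of what such a special-tokens mask buys you (the token ids below are made up; in the actual example the mask would come from the tokenizer's `get_special_tokens_mask`):

```python
import numpy as np

# Made-up ids: 0 = <s>, 2 = </s>, 1 = <pad>, the rest ordinary tokens.
input_ids = np.array([0, 5, 6, 7, 2, 1, 1])
special_ids = {0, 1, 2}
special_mask = np.array([tok in special_ids for tok in input_ids])

# Candidate masking positions are drawn only from non-special tokens,
# so <s>, </s>, and padding can never be selected.
candidates = np.argwhere(~special_mask).ravel()
print(candidates)
```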
I did not adapt `add_insertion_noise` to work well with padded sequences, so the inserted noise may occur ANYWHERE. It is unclear whether this is intended behavior.
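To illustrate why the noise can land anywhere, here is a toy, fairseq-inspired sketch of insertion noise (the mask id 4 and the exact sampling are assumptions, not the PR's code): insertion positions are drawn uniformly over the whole (extended) sequence, with no regard for special tokens or padding.

```python
import numpy as np

def add_insertion_noise(tokens, p=0.25, mask_id=4, rng=None):
    rng = rng or np.random.default_rng(0)
    n_insert = int(np.ceil(len(tokens) * p))
    total = len(tokens) + n_insert
    # Positions are sampled uniformly over the whole output sequence,
    # so noise can land before, between, or after any token.
    noise_pos = rng.choice(total, size=n_insert, replace=False)
    out = np.full(total, mask_id, dtype=np.int64)
    out[np.setdiff1d(np.arange(total), noise_pos)] = tokens
    return out
```

Note that nothing stops a noise position from falling after the final token, which is why a corrupted sequence need not end with EOS.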
Alternatively, we could implement all this processing on the dataset level and use `Dataset.map`. This has some advantages:

...

and disadvantages:
- less true to the original fairseq implementation in `add_insertion_noise`
- the same sample will always be processed the same way; in a dataloader, that will not be the case because the processing occurs on every iteration rather than once before training
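A toy illustration of that determinism difference (the drop-based "corruption" is made up purely to show the point):

```python
import random

def corrupt(sample, rng):
    # Toy "denoising": randomly drop roughly 30% of the tokens.
    return [tok for tok in sample if rng.random() > 0.3]

sample = list(range(10))

# Dataset-level (.map): corrupted once with a fixed state, then frozen.
static = corrupt(sample, random.Random(0))

# Dataloader-level: the RNG state advances, so each epoch differs.
rng = random.Random(0)
epoch1 = corrupt(sample, rng)
epoch2 = corrupt(sample, rng)

print(static == epoch1, epoch1 == epoch2)
```

The first comparison is true (same seed, same single pass) while the second is false, which is exactly the dynamic-noising property the dataloader-level approach preserves.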
Questions/Uncertainties
`add_insertion_noise` can insert noise anywhere (also in fairseq), which means that it will also overwrite special tokens and that sequences don't necessarily end with an EOS token. Is that a problem?
Perhaps the `datasets` team can chime in, too.

Before submitting
Did you read the contributor guideline, Pull Request section?
Who can review?