Sum loss instead of mean loss should be used if gradient accumulation step is larger than 1 when training a language model #24725
Comments
Hi @Atry Your description is correct. However, the loss logic is implemented in each model class, and therefore it cannot see multiple batches in a single model forward pass (that's probably the main reason we simply use the mean). The easiest way to get a correct computation is to modify the trainer class: given the loss from the model output, compute back the sum of losses in a batch (by taking into account the sequence length, or the total number of tokens that are meaningful, i.e. not padding tokens etc.), send this new custom loss value on to compute the gradients, and then accumulate it.
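A rough sketch of that idea (not the actual implementation; it assumes a causal LM whose forward returns the mean loss over labels that are not `-100`, after the usual one-token shift, and a batch that contains a `labels` key):

```python
import torch
from transformers import Trainer


class SumLossTrainer(Trainer):
    """Hypothetical Trainer subclass that converts the model's mean loss back to a sum."""

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs["labels"]
        outputs = model(**inputs)
        mean_loss = outputs.loss  # mean over the tokens that actually contribute to the loss

        # Causal LMs shift the labels by one inside forward, so the contributing
        # positions are the shifted labels that are not the ignore index (-100).
        num_loss_tokens = (labels[..., 1:] != -100).sum()
        sum_loss = mean_loss * num_loss_tokens  # undo the mean

        return (sum_loss, outputs) if return_outputs else sum_loss
```

The per-batch token counts would still have to be accumulated somewhere so the gradient can be rescaled to a per-token average at the end of the accumulation cycle.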
Computing back the gradient would damage the precision if the gradient is in a low-precision floating-point dtype.
An idea is to switch all models to a sum loss.
By the way, there is another example of the issue in
This would be a big breaking change, so it is not an option.
I don't think it would produce a big difference if, at the end, we still use some form of mean after we accumulate (sum) all the gradients (say, dividing by the total number of non-padding tokens appearing across all the batches of a gradient accumulation cycle). When the loss is computed as a sum within a batch, it actually requires specific work to get back to the usual definition of that loss (say, the average non-padding-token loss) once we sum over all batches. (Here I only mention non-padding tokens, but the loss definition can get very complex depending on the task and the specific model.)
As studied in https://arxiv.org/abs/1711.00489, changing the batch size has the side effect of also changing the learning rate per sample (and the learning rate per token), even when the learning rate per iteration is unchanged. However, their analysis of their experimental results is nonsense. The actual explanation is that the side effect is simply due to the mean loss; a sum loss would not lead to this side effect.
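For plain SGD without momentum or weight decay (a simplified sketch of that argument), the two update rules are:

$$\theta \leftarrow \theta - \eta\,\frac{1}{B}\sum_{i=1}^{B}\nabla\ell_i \quad\text{(mean loss)} \qquad\qquad \theta \leftarrow \theta - \eta\sum_{i=1}^{B}\nabla\ell_i \quad\text{(sum loss)}$$

With the mean loss each sample's gradient is scaled by $\eta/B$, so changing the batch size $B$ changes the effective per-sample learning rate even though $\eta$ per iteration is fixed; with the sum loss each sample is scaled by $\eta$ regardless of $B$.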
If you are not happy with the loss computation inside the model, you can just not pass the labels and compute the loss yourself from the logits. As @ydshieh mentioned, a breaking change across all models of this magnitude is not possible.
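For example, something along these lines (a sketch only; the one-token shift mirrors what causal LM models do internally, and `gpt2` is just a stand-in checkpoint):

```python
import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("an example sentence", return_tensors="pt")
outputs = model(**inputs)  # no labels passed -> no loss computed inside the model

# Shift so that tokens < n predict token n, as the models do internally.
logits = outputs.logits[..., :-1, :]
targets = inputs["input_ids"][..., 1:]

loss_fct = CrossEntropyLoss(reduction="sum")  # sum instead of mean
sum_loss = loss_fct(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
num_tokens = targets.numel()  # keep this count so the sums can be normalized after accumulation
```

Padded positions would additionally need to be masked out, e.g. by setting their targets to `-100` so the default `ignore_index` drops them.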
Good idea! I wonder if the
The Trainer already divides the loss by the number of gradient accumulation steps, and there are tests in the CI to ensure that training with batch size X, and training with batch size X/g and g gradient accumulation steps, yield the same results.
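A small self-contained illustration of that equivalence (hypothetical numbers; two equally sized micro-batches):

```python
import torch

# Hypothetical per-token losses, split into two micro-batches of equal size.
per_token_losses = torch.tensor([0.4, 0.6, 0.1, 0.9])

large_batch_mean = per_token_losses.mean()                                   # 0.5
accumulated = sum(chunk.mean() / 2 for chunk in per_token_losses.chunk(2))   # 0.5 as well
assert torch.isclose(accumulated, large_batch_mean)

# The two quantities match exactly only when every micro-batch contains the same
# number of loss-contributing tokens.
```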
Suppose you have a dataset of two samples used for unsupervised learning of a decoder-only language model: sample 1 contains 11 tokens and sample 2 contains 101 tokens. When training at batch size 1 without padding, the model returns the mean loss over each sample separately. In current `transformers`, accumulating these per-sample means gives each token of sample 1 ten times the weight of each token of sample 2.
IMHO ideally the loss should be 0.82727
where does 0.82727 come from?
I believe in `transformers` the loss computation in the model forward already takes care of the padding tokens. If you find a HF causal LM model that has a loss computation (in the model forward) that doesn't take care of the padding token, please let us know. 🙏
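A quick way to check that behavior in isolation (a toy example, not taken from any model's code): `torch.nn.CrossEntropyLoss` ignores positions labeled `-100` by default, so padded positions do not affect the mean.

```python
import torch
from torch.nn import CrossEntropyLoss

torch.manual_seed(0)
logits = torch.randn(5, 32)                    # 5 positions, vocab size 32
labels = torch.tensor([3, 7, 1, -100, -100])   # last two positions are padding

loss_all = CrossEntropyLoss()(logits, labels)            # mean over the 3 real labels
loss_real = CrossEntropyLoss()(logits[:3], labels[:3])   # same value, padding ignored
assert torch.allclose(loss_all, loss_real)
```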
You are right. I misunderstood the implementation. I just updated my previous comments. Thank you!
Thanks! As mentioned earlier:
I ran into the same issue. The gradient accumulation result is much worse than using a large per-device batch size. I assume the main reason is that gradient accumulation macro-averages the loss scores, whereas they should be micro-averaged. I think this problem is critical because it strongly affects results for LMs (which have variable sequence lengths across batches); otherwise the training result is suboptimal.
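A tiny numeric illustration of the macro vs. micro distinction (token counts match the earlier example; the per-token losses are made up):

```python
# Two samples: 10 and 100 loss-contributing tokens, with hypothetical mean token losses.
num_tokens = [10, 100]
mean_losses = [0.5, 0.2]

macro_avg = sum(mean_losses) / len(mean_losses)                                     # 0.35
micro_avg = sum(n * l for n, l in zip(num_tokens, mean_losses)) / sum(num_tokens)   # ~0.227

# Accumulating per-batch mean losses optimizes the macro average, which weights each
# token of the short sample 10x more heavily than each token of the long sample.
print(macro_avg, micro_avg)
```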
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
Not applicable, because this is a design issue, not a runtime error.
Who can help?
@sgugger, @ArthurZucker and @younesbelkada
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
Given gradient accumulation step 2, batch size 1, and a training set of 2 samples, where sample 1 contains 11 tokens and sample 2 contains 101 tokens, train a decoder-only model with unsupervised learning (the first token in each sample is untrainable). The resulting gradient will be different from training the same model on the same dataset with gradient accumulation step 1 and batch size 2.
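A rough sketch of this reproduction (assumptions not in the original report: a tiny randomly initialized GPT-2 config instead of a real checkpoint, dropout disabled for determinism, right-padding with label `-100`, and Trainer-style scaling of each micro-batch loss by 1/accumulation_steps):

```python
import copy
import torch
from transformers import GPT2Config, GPT2LMHeadModel

torch.manual_seed(0)
config = GPT2Config(n_layer=2, n_head=2, n_embd=64, vocab_size=100,
                    resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0)
model_a = GPT2LMHeadModel(config)          # accumulation 2, batch size 1
model_b = copy.deepcopy(model_a)           # accumulation 1, batch size 2

sample_1 = torch.randint(0, 100, (11,))    # 11 tokens -> 10 trainable positions
sample_2 = torch.randint(0, 100, (101,))   # 101 tokens -> 100 trainable positions

# Setting A: two backward passes, each mean loss divided by the accumulation steps.
for sample in (sample_1, sample_2):
    ids = sample.unsqueeze(0)
    loss = model_a(input_ids=ids, labels=ids).loss
    (loss / 2).backward()

# Setting B: one batch of 2, with sample 1 right-padded and its padding masked out.
input_ids = torch.zeros(2, 101, dtype=torch.long)
input_ids[0, :11], input_ids[1] = sample_1, sample_2
attention_mask = torch.zeros(2, 101, dtype=torch.long)
attention_mask[0, :11], attention_mask[1] = 1, 1
labels = input_ids.masked_fill(attention_mask == 0, -100)
model_b(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss.backward()

# The accumulated gradients differ because the per-batch mean losses weight tokens unequally.
print(torch.allclose(model_a.transformer.wte.weight.grad,
                     model_b.transformer.wte.weight.grad))  # expected: False
```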
The reason is that `transformers` currently uses a mean loss for most models (if not all). As a result, each token in sample 1 produces a gradient 10 times larger than that of each token in sample 2.
Expected behavior
The setting (gradient accumulation step 2, batch size 1) should produce the same gradient as the setting (gradient accumulation step 1, batch size 2).